torch.ao.nn.intrinsic.quantized.modules.conv_relu 的源代码
```python
import torch import torch.ao.nn.intrinsic import torch.ao.nn.intrinsic.qat import torch.nn.functional as F import torch.ao.nn.quantized as nnq from torch.nn.utils import fuse_conv_bn_weights __all__ = [ "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", ] _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding # TODO: 将公共部分提取到ConvNd[docs]class ConvReLU1d(nnq.Conv1d): r""" ConvReLU1d模块是Conv1d和ReLU的融合模块 我们采用与 :class:`torch.ao.nn.quantized.Conv1d` 相同的接口。 属性: 与 torch.ao.nn.quantized.Conv1d 相同 """ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None): super().__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype) def forward(self, input): # 暂时使用len(shape)代替ndim,因为JIT问题 # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 3: raise ValueError("输入形状必须是`(N, C, L)`!") if self.padding_mode != 'zeros': # Conv1d中的padding存储为(p, p),需要获取(p,) _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) input = F.pad(input, _reversed_padding_repeated_twice, mode=self.padding_mode) return torch.ops.quantized.conv1d_relu( input, self._packed_params, self.scale, self.zero_point) def _get_name(self): return 'QuantizedConvReLU1d' @classmethod def from_float(cls, mod): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) return super().from_float(mod) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \ 
"BatchNorm1d应在转换为参考模块之前融合到Conv1d中" return super().from_reference(ref_qconv[0], output_scale, output_zero_point)[docs]class ConvReLU2d(nnq.Conv2d): r""" ConvReLU2d模块是Conv2d和ReLU的融合模块 我们采用与 :class:`torch.ao.nn.quantized.Conv2d` 相同的接口。 属性: 与 torch.ao.nn