Shortcuts

torch.ao.nn.quantized.dynamic.modules.linear 的源代码

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
import torch.ao.nn.intrinsic as nni

__all__ = [
    "Linear",
]


[docs]class Linear(nnq.Linear): r""" 一个动态量化的线性模块,输入和输出为浮点张量。 我们采用与 `torch.nn.Linear` 相同的接口,请参阅 https://pytorch.org/docs/stable/nn.html#torch.nn.Linear 获取文档。 类似于 :class:`torch.nn.Linear`,属性将在模块创建时随机初始化,并将在之后被覆盖。 属性: weight (Tensor): 模块的不可学习的量化权重,形状为 :math:`(\text{out\_features}, \text{in\_features})`。 bias (Tensor): 模块的不可学习的浮点偏置,形状为 :math:`(\text{out\_features})`。如果 :attr:`bias` 为 ``True``, 则值初始化为零。 示例:: >>> # xdoctest: +SKIP >>> m = nn.quantized.dynamic.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30]) """ # 此类的版本与父类 nnq.Linear 不同 _version = 4 def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): super().__init__(in_features, out_features, bias_, dtype=dtype) # 我们不在这里处理缓冲区或属性等,以保持模块的简单性。*一切*都只是 Python 属性。 # 序列化逻辑在下面的序列化和反序列化模块中显式处理 self.version = 4 def forward(self, x): # 注意我们可以处理 self.bias == None 的情况。 if self._packed_params.dtype == torch.qint8: if self.version is None or self.version < 4: Y = torch.ops.quantized.linear_dynamic( x, self._packed_params._packed_params) else: Y = torch.ops.quantized.linear_dynamic( x, self._packed_params._packed_params, reduce_range=True) elif self._packed_params.dtype == torch.float16: Y = torch.ops.quantized.linear_dynamic_fp16( x, self._packed_params._packed_params) else: raise RuntimeError('不支持的动态量化线性数据类型!') return Y.to(x.dtype) def _get_name(self): return 'DynamicQuantizedLinear' def extra_repr(self): extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format( self.in_features, self.out_features, self._packed_params.dtype ) if self._packed_params.dtype == torch.qint8: extra_repr_str += f', qscheme={self.weight().qscheme()}' return extra_repr_str def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) self.version = version super()._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs)
[docs] @classmethod def from_float(cls, mod): r"""从浮点模块或 qparams_dict 创建一个动态量化模块 参数: mod (Module): 一个浮点模块,由 torch.ao.quantization 工具生成或由用户提供 """ float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.ao.nn.intrinsic.modules.fused.LinearReLU, torch<span
优云智算