torch.ao.nn.quantized.dynamic.modules.linear 的源代码
import torch
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
import torch.ao.nn.intrinsic as nni
__all__ = [
"Linear",
]
[docs]class Linear(nnq.Linear):
r"""
一个动态量化的线性模块,输入和输出为浮点张量。
我们采用与 `torch.nn.Linear` 相同的接口,请参阅
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear 获取文档。
类似于 :class:`torch.nn.Linear`,属性将在模块创建时随机初始化,并将在之后被覆盖。
属性:
weight (Tensor): 模块的不可学习的量化权重,形状为 :math:`(\text{out\_features}, \text{in\_features})`。
bias (Tensor): 模块的不可学习的浮点偏置,形状为 :math:`(\text{out\_features})`。如果 :attr:`bias` 为 ``True``,
则值初始化为零。
示例::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
# 此类的版本与父类 nnq.Linear 不同
_version = 4
def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias_, dtype=dtype)
# 我们不在这里处理缓冲区或属性等,以保持模块的简单性。*一切*都只是 Python 属性。
# 序列化逻辑在下面的序列化和反序列化模块中显式处理
self.version = 4
def forward(self, x):
# 注意我们可以处理 self.bias == None 的情况。
if self._packed_params.dtype == torch.qint8:
if self.version is None or self.version < 4:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params)
else:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params, reduce_range=True)
elif self._packed_params.dtype == torch.float16:
Y = torch.ops.quantized.linear_dynamic_fp16(
x, self._packed_params._packed_params)
else:
raise RuntimeError('不支持的动态量化线性数据类型!')
return Y.to(x.dtype)
def _get_name(self):
return 'DynamicQuantizedLinear'
def extra_repr(self):
extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
self.in_features, self.out_features, self._packed_params.dtype
)
if self._packed_params.dtype == torch.qint8:
extra_repr_str += f', qscheme={self.weight().qscheme()}'
return extra_repr_str
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
self.version = version
super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
[docs] @classmethod
def from_float(cls, mod):
r"""从浮点模块或 qparams_dict 创建一个动态量化模块
参数:
mod (Module): 一个浮点模块,由 torch.ao.quantization 工具生成或由用户提供
"""
float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
torch.ao.nn.intrinsic.modules.fused.LinearReLU, torch<span