# Source code for torch.ao.nn.quantized.modules.linear
from collections.abc import Iterable
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Optional
from .utils import _quantize_weight, _hide_packed_params_repr, WeightedQuantizedModule
# Names exported when this module is star-imported.
__all__ = ['LinearPackedParams', 'Linear']
class LinearPackedParams(torch.nn.Module):
_version = 3
def __init__(self, dtype=torch.qint8):
super().__init__()
self.dtype = dtype
if self.dtype == torch.qint8:
wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
elif self.dtype == torch.float16:
wq = torch.zeros([1, 1], dtype=torch.float)
self.set_weight_bias(wq, None) # type: ignore[possibly-undefined]
@torch.jit.export
def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16:
self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
else:
raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
@torch.jit.export
def _weight_bias(self):
if self.dtype == torch.qint8:
return torch.ops.quantized.linear_unpack(self._packed_params)
elif self.dtype == torch.float16:
return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
else:
raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
def forward(self, x):
    """Return ``x`` unchanged; the packed parameters are not applied here."""
    return x
# Version 1
#   self
#   |--- weight : Tensor
#   |--- bias : Tensor
#
# Version 2
#   self
#   |--- weight : Tensor
#   |--- bias : Tensor
#   |--- dtype : torch.dtype
#
# Version 3
#   self
#   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
#                         of LinearPackedParams
#   |--- dtype : torch.dtype
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'dtype'] = self.dtype
destination[prefix + '_packed_params'] = self._weight_bias()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
self.dtype = torch.qint8
else:
self.dtype = state_dict[prefix + 'dtype']
<span