torch.ao.nn.quantizable.modules.rnn 的源代码
```python
import numbers from typing import Optional, Tuple import warnings import torch from torch import Tensor """ 我们将重新创建所有的RNN模块,因为我们要求模块被分解 成其构建块以便能够观察。 """ __all__ = [ "LSTMCell", "LSTM" ] class LSTMCell(torch.nn.Module): r"""一个可量化的长短期记忆(LSTM)单元。 有关描述和参数类型,请参阅 :class:`~torch.nn.LSTMCell` 示例:: >>> import torch.ao.nn.quantizable as nnqa >>> rnn = nnqa.LSTMCell(10, 20) >>> input = torch.randn(6, 10) >>> hx = torch.randn(3, 20) >>> cx = torch.randn(3, 20) >>> output = [] >>> for i in range(6): ... hx, cx = rnn(input[i], (hx, cx)) ... output.append(hx) """ _FLOAT_MODULE = torch.nn.LSTMCell def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.input_size = input_dim self.hidden_size = hidden_dim self.bias = bias self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) self.gates = torch.ao.nn.quantized.FloatFunctional() self.input_gate = torch.nn.Sigmoid() self.forget_gate = torch.nn.Sigmoid() self.cell_gate = torch.nn.Tanh() self.output_gate = torch.nn.Sigmoid() self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0) self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0) self.hidden_state_dtype: torch.dtype = torch.quint8 self.cell_state_dtype: torch.dtype = torch.quint8 def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: if hidden is None or hidden[0] is None or hidden[1] is None: hidden = self.initialize_hidden(x.shape[0], x.is_quantized) hx, cx = hidden igates = self.igates(x) hgates = self.hgates(hx) gates = self.gates.add(igates, 
hg