
Source code for torch.ao.nn.quantizable.modules.rnn

```python
import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor

"""
我们将重新创建所有的RNN模块,因为我们要求模块被分解
成其构建块以便能够观察。
"""

__all__ = [
    "LSTMCell",
    "LSTM"
]

class LSTMCell(torch.nn.Module):
    r"""一个可量化的长短期记忆(LSTM)单元。

    有关描述和参数类型,请参阅 :class:`~torch.nn.LSTMCell`

    示例::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

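        # Keep the input-to-hidden and hidden-to-hidden projections as separate
        # Linear submodules, with their sum routed through FloatFunctional, so
        # each step can carry its own observer and quantization parameters.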
        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

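        # The gate nonlinearities are kept as submodules rather than functional
        # calls so that they stay visible to observers and can be swapped out
        # during conversion.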
        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

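        # Each elementwise op in the cell-state update gets its own
        # FloatFunctional, so every intermediate tensor is quantized with an
        # independent scale and zero point.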
        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

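        # Default quantization parameters and dtypes to use when the initial
        # hidden and cell states must be created (and quantized) on the fly.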
        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        # Split the fused projection into the four gate pre-activations.
        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        # Cell-state update, with each elementwise op routed through its
        # dedicated FloatFunctional.
        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
        # Zero-initialize the hidden and cell states, quantizing them with the
        # default qparams when the input is quantized.
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c
```
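
Because every arithmetic step in the cell above is expressed as a submodule, the standard eager-mode workflow can place an observer on each of them. The following is a minimal sketch of that flow, not part of the source above; it assumes the stock `torch.ao.quantization.get_default_qconfig`, `prepare`, and `convert` APIs, and the qconfig backend name and calibration loop are illustrative only.

```python
import torch
import torch.ao.nn.quantizable as nnqa

cell = nnqa.LSTMCell(10, 20)

# Attach a qconfig and insert observers on every Linear, activation, and
# FloatFunctional submodule.
cell.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
prepared = torch.ao.quantization.prepare(cell)

# Calibrate with a few float batches so the observers record the ranges of
# the intermediate gate tensors.
for _ in range(4):
    x = torch.randn(3, 10)
    hx = torch.randn(3, 20)
    cx = torch.randn(3, 20)
    prepared(x, (hx, cx))

# Swap the observed submodules for their quantized counterparts.
quantized = torch.ao.quantization.convert(prepared)
```

The point of the sketch is the design choice it exercises: had the gates been computed with bare `+` and `*` operators, `prepare` would have nothing to hook observers onto, which is exactly why the module is decomposed into `Linear`, activation, and `FloatFunctional` building blocks.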