Keras 3 API 文档 / 层 API / 循环层 / GRU单元层

GRU单元层

[source]

GRUCell class

keras.layers.GRUCell(
    units,
    activation="tanh",
    recurrent_activation="sigmoid",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    recurrent_initializer="orthogonal",
    bias_initializer="zeros",
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    reset_after=True,
    seed=None,
    **kwargs
)

GRU层的单元类.

该类处理整个时间序列输入中的一个步骤,而keras.layer.GRU处理整个序列.

参数: units: 正整数,输出空间的维度. activation: 要使用的激活函数.默认:双曲正切(tanh).如果你传递None,则不应用激活函数(即"线性”激活:a(x) = x). recurrent_activation: 用于循环步骤的激活函数.默认:sigmoid(sigmoid).如果你传递None,则不应用激活函数(即"线性”激活:a(x) = x). use_bias: 布尔值,(默认True),该层是否应使用偏置向量. kernel_initializer: kernel权重矩阵的初始化器,用于输入的线性变换.默认:"glorot_uniform". recurrent_initializer: recurrent_kernel权重矩阵的初始化器,用于循环状态的线性变换.默认:"orthogonal". bias_initializer: 偏置向量的初始化器.默认:"zeros". kernel_regularizer: 应用于kernel权重矩阵的正则化函数.默认:None. recurrent_regularizer: 应用于recurrent_kernel权重矩阵的正则化函数.默认:None. bias_regularizer: 应用于偏置向量的正则化函数.默认:None. kernel_constraint: 应用于kernel权重矩阵的约束函数.默认:None. recurrent_constraint: 应用于recurrent_kernel权重矩阵的约束函数.默认:None. bias_constraint: 应用于偏置向量的约束函数.默认:None. dropout: 0到1之间的浮点数.用于输入线性变换的单元丢弃比例.默认:0. recurrent_dropout: 0到1之间的浮点数.用于循环状态线性变换的单元丢弃比例.默认:0. reset_after: GRU约定(是否在矩阵乘法之后或之前应用重置门).False = "before",True = "after"(默认且与cuDNN兼容). seed: 用于丢弃的随机种子.

调用参数: inputs: 一个2D张量,形状为(batch, features). states: 一个2D张量,形状为(batch, units),即从前一个时间步的状态. training: Python布尔值,指示层应在训练模式还是推理模式下运行.仅在使用dropoutrecurrent_dropout时相关.

示例:

>>> inputs = np.random.random((32, 10, 8))
>>> rnn = keras.layers.RNN(keras.layers.GRUCell(4))
>>> output = rnn(inputs)
>>> output.shape
(32, 4)
>>> rnn = keras.layers.RNN(
...    keras.layers.GRUCell(4),
...    return_sequences=True,
...    return_state=True)
>>> whole_sequence_output, final_state = rnn(inputs)
>>> whole_sequence_output.shape
(32, 10, 4)
>>> final_state.shape
(32, 4)