ray.rllib.models.torch.torch_modelv2.TorchModelV2.get_initial_state#

TorchModelV2.get_initial_state() List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]#

获取模型的初始递归状态值。

返回:

包含 RNN 初始隐藏状态的 np.array(对于 tf)或 Tensor(对于 torch)对象列表(如果适用)。

import numpy as np
from ray.rllib.models.modelv2 import ModelV2
class MyModel(ModelV2):
    # ...
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]