ray.rllib.models.torch.torch_modelv2.TorchModelV2.get_initial_state#
- TorchModelV2.get_initial_state() List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] #
获取模型的初始递归状态值。
- 返回:
包含 RNN 初始隐藏状态的 np.array(对于 tf)或 Tensor(对于 torch)对象列表(如果适用)。
import numpy as np from ray.rllib.models.modelv2 import ModelV2 class MyModel(ModelV2): # ... def get_initial_state(self): return [ np.zeros(self.cell_size, np.float32), np.zeros(self.cell_size, np.float32), ]