ray.rllib.models.modelv2.ModelV2.forward#

ModelV2.forward(input_dict: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], state: List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], seq_lens: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor)[源代码]#

使用给定的输入张量和状态调用模型。

任何复杂的观察(字典、元组等)将在传递给 forward() 之前由 __call__ 解包。要访问展平的观察张量,请参考 input_dict[“obs_flat”]。

这个方法可以被调用任意次数。在急切执行中,每次调用 forward() 都会急切地评估模型。在符号执行中,每次调用 forward 都会创建一个计算图,该图在这个模型的变量上操作(即共享权重)。

自定义模型应重写此方法而不是 __call__。

参数:
  • input_dict – 输入张量的字典,包括 “obs”, “obs_flat”, “prev_action”, “prev_reward”, “is_training”, “eps_id”, “agent_id”, “infos”, 和 “t”。

  • state – 状态张量列表,其大小与 get_initial_state 返回的大小 + 批量维度匹配

  • seq_lens – 1d 张量,持有输入序列的长度

返回:

一个由模型输出张量(大小为 [BATCH, num_outputs])和新的 RNN 状态列表(如果有)组成的元组。

import numpy as np
from ray.rllib.models.modelv2 import ModelV2
class MyModel(ModelV2):
    # ...
    def forward(self, input_dict, state, seq_lens):
        model_out, self._value_out = self.base_model(
            input_dict["obs"])
        return model_out, state