ray.rllib.models.modelv2.ModelV2.forward#
- ModelV2.forward(input_dict: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], state: List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], seq_lens: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor)[源代码]#
使用给定的输入张量和状态调用模型。
任何复杂的观察(字典、元组等)将在传递给 forward() 之前由 __call__ 解包。要访问展平的观察张量,请参考 input_dict[“obs_flat”]。
这个方法可以被调用任意次数。在急切执行中,每次调用 forward() 都会急切地评估模型。在符号执行中,每次调用 forward 都会创建一个计算图,该图在这个模型的变量上操作(即共享权重)。
自定义模型应重写此方法而不是 __call__。
- 参数:
input_dict – 输入张量的字典,包括 “obs”, “obs_flat”, “prev_action”, “prev_reward”, “is_training”, “eps_id”, “agent_id”, “infos”, 和 “t”。
state – 状态张量列表,其大小与 get_initial_state 返回的大小 + 批量维度匹配
seq_lens – 1d 张量,持有输入序列的长度
- 返回:
一个由模型输出张量(大小为 [BATCH, num_outputs])和新的 RNN 状态列表(如果有)组成的元组。
import numpy as np from ray.rllib.models.modelv2 import ModelV2 class MyModel(ModelV2): # ... def forward(self, input_dict, state, seq_lens): model_out, self._value_out = self.base_model( input_dict["obs"]) return model_out, state