ray.rllib.core.learner.learner.Learner.compute_losses#

Learner.compute_losses(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) Dict[str, Any][源代码]#

计算正在优化的模块的损失。

此方法必须由 MultiRLModule 特定的学习者重写,以定义特定的损失计算逻辑。如果算法是单智能体,则应仅重写 compute_loss_for_module()。如果算法使用独立的多智能体学习(这是 RLlib 多智能体设置的默认行为),也应仅重写 compute_loss_for_module(),但它将为 MultiRLModule 内的每个单独 RLModule 调用。建议不要在此方法中计算任何前向传递,而是使用 RLModule(s) 的 forward_train() 输出计算所需的损失张量。请参阅此处自定义损失函数示例脚本:ray-project/ray

参数:
  • fwd_out – 在训练期间(self.update()),调用底层 MultiRLModule (self.module) 的 forward_train() 方法的输出。

  • batch – 用于计算 fwd_out 的训练批次。

返回:

一个将模块ID映射到各个损失项的字典。