ray.rllib.core.learner.learner.Learner.compute_losses#
- Learner.compute_losses(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) Dict[str, Any] [源代码]#
计算正在优化的模块的损失。
此方法必须由 MultiRLModule 特定的学习者重写,以定义特定的损失计算逻辑。如果算法是单智能体,则应仅重写
compute_loss_for_module()
。如果算法使用独立的多智能体学习(这是 RLlib 多智能体设置的默认行为),也应仅重写compute_loss_for_module()
,但它将为 MultiRLModule 内的每个单独 RLModule 调用。建议不要在此方法中计算任何前向传递,而是使用 RLModule(s) 的forward_train()
输出计算所需的损失张量。请参阅此处自定义损失函数示例脚本:ray-project/ray- 参数:
fwd_out – 在训练期间(
self.update()
),调用底层 MultiRLModule (self.module
) 的forward_train()
方法的输出。batch – 用于计算
fwd_out
的训练批次。
- 返回:
一个将模块ID映射到各个损失项的字典。