ray.rllib.core.learner.learner.Learner._update#
- abstract Learner._update(batch: Dict[str, Any], **kwargs) Tuple[Any, Any, Any] [源代码]#
包含图内/可追踪更新步骤的所有逻辑。
特定框架的子类必须实现此方法。这应包括对 RLModule 的
forward_train
、compute_loss
、compute_gradients
、postprocess_gradients
和apply_gradients
方法的调用,并返回包含所有单独结果的元组。- 参数:
batch – 训练批次已经转换为一个字典,映射字符串到(可能是嵌套的)张量。
kwargs – 向前兼容的关键字参数。
- 返回:
RLModule 的
forward_train()
输出- loss_per_module 字典将模块ID映射到各自的损失
张量
一个将模块ID映射到指标键/值对的metrics字典。
- 返回类型:
A tuple consisting of