ray.rllib.core.learner.learner.Learner._update#

abstract Learner._update(batch: Dict[str, Any], **kwargs) Tuple[Any, Any, Any][源代码]#

包含图内/可追踪更新步骤的所有逻辑。

特定框架的子类必须实现此方法。这应包括对 RLModule 的 forward_traincompute_losscompute_gradientspostprocess_gradientsapply_gradients 方法的调用,并返回包含所有单独结果的元组。

参数:
  • batch – 训练批次已经转换为一个字典,映射字符串到(可能是嵌套的)张量。

  • kwargs – 向前兼容的关键字参数。

返回:

  1. RLModule 的 forward_train() 输出

  2. loss_per_module 字典将模块ID映射到各自的损失

    张量

  3. 一个将模块ID映射到指标键/值对的metrics字典。

返回类型:

A tuple consisting of