ray.rllib.core.learner.learner.Learner.计算梯度#

abstract Learner.compute_gradients(loss_per_module: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], **kwargs) Dict[Hashable, torch.Tensor | tf.Variable][源代码]#

基于给定的损失计算梯度。

参数:
  • loss_per_module – 字典映射模块ID到它们各自的总体损失项,由单独的 compute_loss_for_module() 调用计算。总体总损失(所有模块的损失项之和)存储在 loss_per_module[ALL_MODULES] 下。

  • **kwargs – 向前兼容的关键字参数。

返回:

与 self._params 相同的(扁平)格式中的梯度。请注意,所有顶级结构,如模块ID,将不再存在于返回的字典中。它只会将参数张量引用映射到它们各自的梯度张量。