ray.rllib.core.learner.learner.Learner.计算梯度#
- abstract Learner.compute_gradients(loss_per_module: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], **kwargs) Dict[Hashable, torch.Tensor | tf.Variable] [源代码]#
基于给定的损失计算梯度。
- 参数:
loss_per_module – 字典映射模块ID到它们各自的总体损失项,由单独的
compute_loss_for_module()
调用计算。总体总损失(所有模块的损失项之和)存储在loss_per_module[ALL_MODULES]
下。**kwargs – 向前兼容的关键字参数。
- 返回:
与 self._params 相同的(扁平)格式中的梯度。请注意,所有顶级结构,如模块ID,将不再存在于返回的字典中。它只会将参数张量引用映射到它们各自的梯度张量。