ray.rllib.policy.torch_policy_v2.TorchPolicyV2.额外梯度处理#

TorchPolicyV2.extra_grad_process(optimizer: torch.optim.Optimizer, loss: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor][源代码]#

在每次 optimizer.zero_grad() + loss.backward() 调用后被调用。

针对每个 self._optimizers/loss-value 对调用。允许在调用 optimizer.step() 之前进行梯度处理。例如,用于梯度裁剪。

参数:
  • optimizer – 一个 torch 优化器对象。

  • loss – 与优化器关联的损失张量。

返回:

包含梯度处理步骤信息的字典。