ray.rllib.policy.torch_policy_v2.TorchPolicyV2.额外梯度处理#
- TorchPolicyV2.extra_grad_process(optimizer: torch.optim.Optimizer, loss: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [源代码]#
在每次 optimizer.zero_grad() + loss.backward() 调用后被调用。
针对每个 self._optimizers/loss-value 对调用。允许在调用 optimizer.step() 之前进行梯度处理。例如,用于梯度裁剪。
- 参数:
optimizer – 一个 torch 优化器对象。
loss – 与优化器关联的损失张量。
- 返回:
包含梯度处理步骤信息的字典。