ray.rllib.policy.Policy.compute_gradients#
- Policy.compute_gradients(postprocessed_batch: SampleBatch) Tuple[List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] [源代码]#
在给定一批经验的情况下计算梯度。
子类必须实现此方法与
apply_gradients()
或learn_on_batch()
的组合。- 参数:
postprocessed_batch – 用于计算梯度的 SampleBatch 对象。
- 返回:
梯度输出值列表。grad_info: 额外的策略特定信息值。
- 返回类型:
grads