ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.grad_stats_fn#
- EagerTFPolicyV2.grad_stats_fn(train_batch: SampleBatch, grads: List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [源代码]#
梯度统计函数。返回一个统计数据的字典。
- 参数:
train_batch – 用于训练的 SampleBatch(已使用)。
- 返回:
统计字典。