ray.rllib.evaluation.rollout_worker.RolloutWorker.compute_gradients#
- RolloutWorker.compute_gradients(samples: SampleBatch | MultiAgentBatch | Dict[str, Any], single_agent: bool = None) Tuple[List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], dict] [源代码]#
返回相对于指定样本计算的梯度。
使用策略/策略的 compute_gradients 方法执行计算。根据
self.is_policy_to_train()
跳过不可训练的策略。- 参数:
samples – 使用此工作者的可训练策略计算梯度的 SampleBatch 或 MultiAgentBatch。
- 返回:
在单智能体情况下,由 ModelGradients 和工作者策略的信息字典组成的元组。在多智能体情况下,由一个映射 PolicyID 到 ModelGradients 的字典和一个映射 PolicyID 到额外元数据信息的字典组成的元组。请注意,第一个返回值(grads)可以直接应用于使用工作者的
apply_gradients()
方法的兼容工作者。
import gymnasium as gym from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy worker = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v1"), default_policy_class=PPOTF1Policy) batch = worker.sample() grads, info = worker.compute_gradients(samples)