ray.rllib.evaluation.rollout_worker.RolloutWorker.apply_gradients#
- RolloutWorker.apply_gradients(grads: List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] | Dict[str, List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]]) None [源代码]#
将给定的梯度应用于该工作者的模型。
使用策略/策略的 apply_gradients 方法来执行操作。
- 参数:
grads – 单个 ModelGradients(单代理情况)或一个映射 PolicyIDs 到各自模型梯度结构的字典。
import gymnasium as gym from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy worker = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v1"), default_policy_class=PPOTF1Policy) samples = worker.sample() grads, info = worker.compute_gradients(samples) worker.apply_gradients(grads)