ray.rllib.evaluation.rollout_worker.RolloutWorker.learn_on_batch#

RolloutWorker.learn_on_batch(samples: SampleBatch | MultiAgentBatch | Dict[str, Any]) Dict[源代码]#

根据给定的批次更新策略。

这相当于执行 apply_gradients(compute_gradients(samples)),但可以优化以避免将梯度拉入CPU内存。

参数:

samples – 用于学习的 SampleBatch 或 MultiAgentBatch。

返回:

来自 compute_gradients() 的额外元数据的字典。

import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
worker = RolloutWorker(
  env_creator=lambda _: gym.make("CartPole-v1"),
  default_policy_class=PPOTF1Policy)
batch = worker.sample()
info = worker.learn_on_batch(samples)