ray.rllib.evaluation.rollout_worker.RolloutWorker.learn_on_batch#
- RolloutWorker.learn_on_batch(samples: SampleBatch | MultiAgentBatch | Dict[str, Any]) Dict [源代码]#
根据给定的批次更新策略。
这相当于执行 apply_gradients(compute_gradients(samples)),但可以优化以避免将梯度拉入CPU内存。
- 参数:
samples – 用于学习的 SampleBatch 或 MultiAgentBatch。
- 返回:
来自 compute_gradients() 的额外元数据的字典。
import gymnasium as gym from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy worker = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v1"), default_policy_class=PPOTF1Policy) batch = worker.sample() info = worker.learn_on_batch(samples)