ray.rllib.policy.Policy.learn_on_batch#

Policy.learn_on_batch(samples: SampleBatch) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor][源代码]#

执行一次学习更新,基于 samples

子类必须实现此方法或 compute_gradientsapply_gradients 的组合。

参数:

samples – 用于学习的 SampleBatch 对象。

返回:

来自 compute_gradients() 的额外元数据的字典。

policy, sample_batch = ...
policy.learn_on_batch(sample_batch)