备注

Ray 2.10.0 引入了 RLlib 的“新 API 栈”的 alpha 阶段。Ray 团队计划将算法、示例脚本和文档迁移到新的代码库中,从而在 Ray 3.0 之前的后续小版本中逐步替换“旧 API 栈”(例如,ModelV2、Policy、RolloutWorker)。

然而,请注意,到目前为止,只有 PPO(单代理和多代理)和 SAC(仅单代理)支持“新 API 堆栈”,并且默认情况下继续使用旧 API 运行。您可以继续使用现有的自定义(旧堆栈)类。

请参阅此处 以获取有关如何使用新API堆栈的更多详细信息。

LearnerGroup API#

配置 LearnerGroup 和 Learner Workers#

AlgorithmConfig.resources

指定为算法及其 ray 角色/工作者分配的资源。

AlgorithmConfig.rl_module

设置配置的 RLModule 设置。

AlgorithmConfig.training

设置与训练相关的配置。

构建一个学习小组#

AlgorithmConfig.build_learner_group

基于 self 中的设置构建并返回一个新的 LearnerGroup 对象。

LearnerGroup

n(可能是远程的)学习者工作者的协调者。

学习者 API#

构建学习者#

AlgorithmConfig.build_learner

基于 self 中的设置构建并返回一个新的 Learner 对象。

Learner

学习者的基类。

Learner.build

构建学习者。

Learner._check_is_built

Learner._make_module

构建学习者的多智能体强化学习模块。

执行更新#

Learner.update_from_batch

对给定的训练批次执行 num_iters 次小批量更新。

Learner.update_from_episodes

对给定的一系列片段执行 num_iters 次小批量更新。

Learner.before_gradient_based_update

在基于梯度的更新完成之前调用。

Learner._update

包含图内/可追踪更新步骤的所有逻辑。

Learner.after_gradient_based_update

在基于梯度的更新完成后调用。

计算损失#

Learner.compute_losses

计算正在优化的模块的损失。

Learner.compute_loss_for_module

计算单个模块的损失。

Learner._is_module_compatible_with_learner

检查模块是否与学习者兼容。

Learner._get_tensor_variable

返回一个特定于框架的张量变量,并赋予初始给定的值。

配置优化器#

Learner.configure_optimizers_for_module

为给定的 module_id 配置一个优化器。

Learner.configure_optimizers

配置、创建并注册此学习器的优化器。

Learner.register_optimizer

使用 ModuleID、名称、参数列表和学习率调度器注册一个优化器。

Learner.get_optimizers_for_module

返回一个 (优化器名称, 优化器实例) 元组列表,对应于 module_id。

Learner.get_optimizer

返回在给定的 module_id 和名称下配置的优化器对象。

Learner.get_parameters

返回模块的参数列表。

Learner.get_param_ref

返回一个可哈希的、对可训练参数的引用。

Learner.filter_param_dict_for_optimizer

将给定的 ParamDict 缩减为仅包含给定优化器的参数。

Learner._check_registered_optimizer

检查给定的优化器和参数对于框架是否有效。

Learner._set_optimizer_lr

更新给定本地优化器的学习率。

Learner._get_clip_function

根据框架返回要使用的梯度裁剪函数。

梯度计算#

Learner.compute_gradients

基于给定的损失计算梯度。

Learner.postprocess_gradients

对梯度应用潜在的后处理操作。

Learner.postprocess_gradients_for_module

对给定模块的梯度应用后处理操作。

Learner.apply_gradients

将梯度应用于 MultiRLModule 参数。

保存、加载、检查点以及恢复状态#

Learner.get_state

返回实现类的当前状态为字典。

Learner.set_state

将实现类的状态设置为给定的状态字典。

Learner.save_to_path

将实现类的状态(或 state)保存到 path

Learner.restore_from_path

从给定的路径恢复实现类的状态。

Learner.from_checkpoint

从给定位置创建一个新的 Checkpointable 实例并返回它。

Learner._get_optimizer_state

返回当前在此学习者中注册的所有优化器的状态。

Learner._set_optimizer_state

设置此学习器中当前注册的所有优化器的状态。

添加和移除模块#

Learner.add_module

将一个模块添加到底层的 MultiRLModule 中。

Learner.remove_module

从学习者中移除一个模块。