ray.rllib.core.learner.learner.Learner#

class ray.rllib.core.learner.learner.Learner(*, config: AlgorithmConfig, module_spec: RLModuleSpec | MultiRLModuleSpec | None = None, module: RLModule | None = None)[源代码]#

基类:Checkpointable

学习者的基类。

此类将用于训练 RLModules。它负责定义损失函数,并更新其拥有的神经网络权重。它还提供了一种在多智能体场景中训练过程中添加/移除 RLModules 模块的方法(这对于基于联盟的训练非常有用)。

此类针对 TensorFlow 和 Torch 的具体实现,填充了分布式训练以及计算和应用梯度的框架特定实现细节。用户不需要子类化此类,而是应继承 TensorFlow 或 Torch 特定的子类来实现其算法特定的更新逻辑。

参数:
  • config – AlgorithmConfig 对象,从中派生出构建 Learner 所需的大部分设置。

  • module_spec – 正在训练的 RLModule 的模块规范。如果该模块是单一代理模块,构建模块后它将被转换为具有默认键的多代理模块。如果模块是通过 module 参数直接提供的,则可以为空。更多信息请参考 ray.rllib.core.rl_module.RLModuleSpec 或 ray.rllib.core.rl_module.MultiRLModuleSpec。

  • module – 如果 learner 是独立使用的,可以选择直接传入 RLModule 而不是通过 module_spec

注意:我们在这里使用 PPO 和 torch 作为示例,因为许多展示的组件需要实现才能结合在一起。然而,同样的模式通常是适用的。

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# Create a PPO config object first.
config = (
    PPOConfig()
    .framework("torch")
    .training(model={"fcnet_hiddens": [128, 128]})
)

# Create a learner instance directly from our config. All we need as
# extra information here is the env to be able to extract space information
# (needed to construct the RLModule inside the Learner).
learner = config.build_learner(env=env)

# Take one gradient update on the module and report the results.
# results = learner.update(...)

# Add a new module, perhaps for league based training.
learner.add_module(
    module_id="new_player",
    module_spec=RLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"fcnet_hiddens": [64, 64]},
        catalog_class=PPOCatalog,
    )
)

# Take another gradient update with both previous and new modules.
# results = learner.update(...)

# Remove a module.
learner.remove_module("new_player")

# Will train previous modules only.
# results = learner.update(...)

# Get the state of the learner.
state = learner.get_state()

# Set the state of the learner.
learner.set_state(state)

# Get the weights of the underlying MultiRLModule.
weights = learner.get_state(components=COMPONENT_RL_MODULE)

# Set the weights of the underlying MultiRLModule.
learner.set_state({COMPONENT_RL_MODULE: weights})

扩展模式:

from ray.rllib.core.learner.torch.torch_learner import TorchLearner

class MyLearner(TorchLearner):

   def compute_losses(self, fwd_out, batch):
       # Compute the losses per module based on `batch` and output of the
       # forward pass (`fwd_out`). To access the (algorithm) config for a
       # specific RLModule, do:
       # `self.config.get_config_for_module([moduleID])`.
       return {DEFAULT_MODULE_ID: module_loss}

PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。

方法

add_module

将一个模块添加到底层的 MultiRLModule 中。

after_gradient_based_update

在基于梯度的更新完成后调用。

apply_gradients

将梯度应用于 MultiRLModule 参数。

before_gradient_based_update

在基于梯度的更新完成之前调用。

build

构建学习者。

compute_gradients

基于给定的损失计算梯度。

compute_loss_for_module

计算单个模块的损失。

compute_losses

计算正在优化的模块的损失。

configure_optimizers

配置、创建并注册此学习器的优化器。

configure_optimizers_for_module

为给定的 module_id 配置一个优化器。

filter_param_dict_for_optimizer

将给定的 ParamDict 缩减为仅包含给定优化器的参数。

from_checkpoint

从给定位置创建一个新的 Checkpointable 实例并返回它。

get_metadata

返回可写入的JSON元数据,进一步描述实现类。

get_optimizer

返回在给定的 module_id 和名称下配置的优化器对象。

get_optimizers_for_module

返回一个 (优化器名称, 优化器实例) 元组列表,对应于 module_id。

get_param_ref

返回一个可哈希的、对可训练参数的引用。

get_parameters

返回模块的参数列表。

postprocess_gradients

对梯度应用潜在的后处理操作。

postprocess_gradients_for_module

对给定模块的梯度应用后处理操作。

register_optimizer

使用 ModuleID、名称、参数列表和学习率调度器注册一个优化器。

remove_module

从学习者中移除一个模块。

restore_from_path

从给定的路径恢复实现类的状态。

save_to_path

将实现类的状态(或 state)保存到 path

should_module_be_updated

根据 self.config 返回模块是否应更新。

update_from_batch

对给定的训练批次执行 num_iters 次小批量更新。

update_from_episodes

对给定的一系列片段执行 num_iters 次小批量更新。

属性

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

TOTAL_LOSS_KEY

distributed

学习者是否在分布式模式下运行。

framework

module

正在训练的 MultiRLModule。