ray.rllib.core.learner.learner.Learner#
- class ray.rllib.core.learner.learner.Learner(*, config: AlgorithmConfig, module_spec: RLModuleSpec | MultiRLModuleSpec | None = None, module: RLModule | None = None)[源代码]#
-
学习者的基类。
此类将用于训练 RLModules。它负责定义损失函数,并更新其拥有的神经网络权重。它还提供了一种在多智能体场景中训练过程中添加/移除 RLModules 模块的方法(这对于基于联盟的训练非常有用)。
此类针对 TensorFlow 和 Torch 的具体实现,填充了分布式训练以及计算和应用梯度的框架特定实现细节。用户不需要子类化此类,而是应继承 TensorFlow 或 Torch 特定的子类来实现其算法特定的更新逻辑。
- 参数:
config – AlgorithmConfig 对象,从中派生出构建 Learner 所需的大部分设置。
module_spec – 正在训练的 RLModule 的模块规范。如果该模块是单一代理模块,构建模块后它将被转换为具有默认键的多代理模块。如果模块是通过
module
参数直接提供的,则可以为空。更多信息请参考 ray.rllib.core.rl_module.RLModuleSpec 或 ray.rllib.core.rl_module.MultiRLModuleSpec。module – 如果 learner 是独立使用的,可以选择直接传入 RLModule 而不是通过
module_spec
。
注意:我们在这里使用 PPO 和 torch 作为示例,因为许多展示的组件需要实现才能结合在一起。然而,同样的模式通常是适用的。
import gymnasium as gym from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule ) from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID from ray.rllib.core.rl_module.rl_module import RLModuleSpec env = gym.make("CartPole-v1") # Create a PPO config object first. config = ( PPOConfig() .framework("torch") .training(model={"fcnet_hiddens": [128, 128]}) ) # Create a learner instance directly from our config. All we need as # extra information here is the env to be able to extract space information # (needed to construct the RLModule inside the Learner). learner = config.build_learner(env=env) # Take one gradient update on the module and report the results. # results = learner.update(...) # Add a new module, perhaps for league based training. learner.add_module( module_id="new_player", module_spec=RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict={"fcnet_hiddens": [64, 64]}, catalog_class=PPOCatalog, ) ) # Take another gradient update with both previous and new modules. # results = learner.update(...) # Remove a module. learner.remove_module("new_player") # Will train previous modules only. # results = learner.update(...) # Get the state of the learner. state = learner.get_state() # Set the state of the learner. learner.set_state(state) # Get the weights of the underlying MultiRLModule. weights = learner.get_state(components=COMPONENT_RL_MODULE) # Set the weights of the underlying MultiRLModule. learner.set_state({COMPONENT_RL_MODULE: weights})
扩展模式:
from ray.rllib.core.learner.torch.torch_learner import TorchLearner class MyLearner(TorchLearner): def compute_losses(self, fwd_out, batch): # Compute the losses per module based on `batch` and output of the # forward pass (`fwd_out`). To access the (algorithm) config for a # specific RLModule, do: # `self.config.get_config_for_module([moduleID])`. return {DEFAULT_MODULE_ID: module_loss}
PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。
方法
将一个模块添加到底层的 MultiRLModule 中。
在基于梯度的更新完成后调用。
将梯度应用于 MultiRLModule 参数。
在基于梯度的更新完成之前调用。
构建学习者。
基于给定的损失计算梯度。
计算单个模块的损失。
计算正在优化的模块的损失。
配置、创建并注册此学习器的优化器。
为给定的 module_id 配置一个优化器。
将给定的 ParamDict 缩减为仅包含给定优化器的参数。
从给定位置创建一个新的 Checkpointable 实例并返回它。
返回可写入的JSON元数据,进一步描述实现类。
返回在给定的 module_id 和名称下配置的优化器对象。
返回一个 (优化器名称, 优化器实例) 元组列表,对应于 module_id。
返回一个可哈希的、对可训练参数的引用。
返回模块的参数列表。
对梯度应用潜在的后处理操作。
对给定模块的梯度应用后处理操作。
使用 ModuleID、名称、参数列表和学习率调度器注册一个优化器。
从学习者中移除一个模块。
从给定的路径恢复实现类的状态。
将实现类的状态(或
state
)保存到path
。根据
self.config
返回模块是否应更新。对给定的训练批次执行
num_iters
次小批量更新。对给定的一系列片段执行
num_iters
次小批量更新。属性
学习者是否在分布式模式下运行。
正在训练的 MultiRLModule。