ray.rllib.core.rl_模块.rl_模块.RL模块#
- class ray.rllib.core.rl_module.rl_module.RLModule(config: RLModuleConfig)[源代码]#
基类:
Checkpointable
,ABC
RLlib 模块的基类。
子类应在它们的 __init__ 方法中调用 super().__init__(config)。以下是调用 forward 方法的伪代码:
创建采样循环的示例:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule ) from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog import gymnasium as gym import torch env = gym.make("CartPole-v1") # Create a single agent RL module spec. module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict = {"hidden": [128, 128]}, catalog_class = PPOCatalog, ) module = module_spec.build() action_dist_class = module.get_inference_action_dist_cls() obs, info = env.reset() terminated = False while not terminated: fwd_ins = {"obs": torch.Tensor([obs])} fwd_outputs = module.forward_exploration(fwd_ins) # this can be either deterministic or stochastic distribution action_dist = action_dist_class.from_logits( fwd_outputs["action_dist_inputs"] ) action = action_dist.sample()[0].numpy() obs, reward, terminated, truncated, info = env.step(action)
训练示例:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule ) from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog import gymnasium as gym import torch env = gym.make("CartPole-v1") # Create a single agent RL module spec. module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict = {"hidden": [128, 128]}, catalog_class = PPOCatalog, ) module = module_spec.build() fwd_ins = {"obs": torch.Tensor([obs])} fwd_outputs = module.forward_train(fwd_ins) # loss = compute_loss(fwd_outputs, fwd_ins) # update_params(module, loss)
推理示例:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule ) from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog import gymnasium as gym import torch env = gym.make("CartPole-v1") # Create a single agent RL module spec. module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict = {"hidden": [128, 128]}, catalog_class = PPOCatalog, ) module = module_spec.build() while not terminated: fwd_ins = {"obs": torch.Tensor([obs])} fwd_outputs = module.forward_inference(fwd_ins) # this can be either deterministic or stochastic distribution action_dist = action_dist_class.from_logits( fwd_outputs["action_dist_inputs"] ) action = action_dist.sample()[0].numpy() obs, reward, terminated, truncated, info = env.step(action)
- 参数:
config – RLModule 的配置。
- 抽象方法:
~_forward_train
: 训练期间的正向传播。~_forward_exploration
: 训练期间的探索前向传递。~_forward_inference
: 推理过程中的前向传递。
备注
规范没有被写成抽象属性的原因是 torch 重载了
__getattr__
和__setattr__
。这意味着如果我们把规范定义为属性,那么属性中的任何错误都会被解释为获取属性的失败,并会调用__getattr__
,这将给出一个关于找不到属性的令人困惑的错误。更多细节请参见:pytorch/pytorch#49726。PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。
方法
返回此模块的多代理包装器。
探索期间的正向传递,从采样器调用。
在评估期间的前向传递,从采样器调用。
训练期间的前向传递,从学习器中调用。
从给定位置创建一个新的 Checkpointable 实例并返回它。
返回实现类自身的 Checkpointable 子组件。
返回用于探索的此 RLModule 的动作分布类。
返回用于推理的此 RLModule 的动作分布类。
返回 RLModule 的初始状态。
返回可写入的JSON元数据,进一步描述实现类。
返回模块的状态字典。
返回用于训练的此 RLModule 的动作分布类。
返回 forward_exploration 方法的输入规格。
返回 forward_inference 方法的输入规格。
返回 forward_train 方法的输入规格。
如果初始状态是一个空字典(或 None),则返回 False。
返回
forward_exploration()
方法的输出规格。返回
forward_inference()
方法的输出规格。返回 forward_train 方法的输出规格。
从给定的路径恢复实现类的状态。
将实现类的状态(或
state
)保存到path
。设置模块的组件。
如果此模块是一个包装器,则返回底层模块。
使用此模块的视图要求更新默认视图要求。
属性