ray.rllib.core.rl_模块.rl_模块.RL模块#

class ray.rllib.core.rl_module.rl_module.RLModule(config: RLModuleConfig)[源代码]#

基类:Checkpointable, ABC

RLlib 模块的基类。

子类应在它们的 __init__ 方法中调用 super().__init__(config)。以下是调用 forward 方法的伪代码:

创建采样循环的示例:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict = {"hidden": [128, 128]},
    catalog_class = PPOCatalog,
)
module = module_spec.build()
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # this can be either deterministic or stochastic distribution
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)

训练示例:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict = {"hidden": [128, 128]},
    catalog_class = PPOCatalog,
)
module = module_spec.build()

fwd_ins = {"obs": torch.Tensor([obs])}
fwd_outputs = module.forward_train(fwd_ins)
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)

推理示例:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict = {"hidden": [128, 128]},
    catalog_class = PPOCatalog,
)
module = module_spec.build()

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_inference(fwd_ins)
    # this can be either deterministic or stochastic distribution
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
参数:

config – RLModule 的配置。

抽象方法:

~_forward_train: 训练期间的正向传播。

~_forward_exploration: 训练期间的探索前向传递。

~_forward_inference: 推理过程中的前向传递。

备注

规范没有被写成抽象属性的原因是 torch 重载了 __getattr____setattr__。这意味着如果我们把规范定义为属性,那么属性中的任何错误都会被解释为获取属性的失败,并会调用 __getattr__,这将给出一个关于找不到属性的令人困惑的错误。更多细节请参见:pytorch/pytorch#49726

PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。

方法

as_multi_rl_module

返回此模块的多代理包装器。

forward_exploration

探索期间的正向传递,从采样器调用。

forward_inference

在评估期间的前向传递,从采样器调用。

forward_train

训练期间的前向传递,从学习器中调用。

from_checkpoint

从给定位置创建一个新的 Checkpointable 实例并返回它。

get_checkpointable_components

返回实现类自身的 Checkpointable 子组件。

get_exploration_action_dist_cls

返回用于探索的此 RLModule 的动作分布类。

get_inference_action_dist_cls

返回用于推理的此 RLModule 的动作分布类。

get_initial_state

返回 RLModule 的初始状态。

get_metadata

返回可写入的JSON元数据,进一步描述实现类。

get_state

返回模块的状态字典。

get_train_action_dist_cls

返回用于训练的此 RLModule 的动作分布类。

input_specs_exploration

返回 forward_exploration 方法的输入规格。

input_specs_inference

返回 forward_inference 方法的输入规格。

input_specs_train

返回 forward_train 方法的输入规格。

is_stateful

如果初始状态是一个空字典(或 None),则返回 False。

output_specs_exploration

返回 forward_exploration() 方法的输出规格。

output_specs_inference

返回 forward_inference() 方法的输出规格。

output_specs_train

返回 forward_train 方法的输出规格。

restore_from_path

从给定的路径恢复实现类的状态。

save_to_path

将实现类的状态(或 state)保存到 path

setup

设置模块的组件。

unwrapped

如果此模块是一个包装器,则返回底层模块。

update_default_view_requirements

使用此模块的视图要求更新默认视图要求。

属性

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

framework