
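Ray 2.10.0 introduces the alpha stage of RLlib's "new API stack". The Ray team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the "old API stack" (for example, ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.

Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the "new API stack", and both continue to run with the old APIs by default. You can continue to use your existing custom (old stack) classes.

See the documentation on the new API stack for more details on how to use it.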

Catalog (Alpha)#

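Catalog is a utility abstraction that modularizes the construction of components for RLModules. It includes information such as how to encode input observation spaces, what action distributions to use, and so on. Each RLModule has a default Catalog. For example, PPOTorchRLModule uses PPOCatalog. To customize an existing RLModule, either change the RLModule directly by inheriting from the class and changing its setup() method, or extend the Catalog class assigned to that RLModule. Use Catalogs only if your customization fits the abstractions that Catalog provides.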


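Modifying Catalogs is an advanced use case, so consider it only if modifying an RLModule or writing a new one doesn't cover your needs. We recommend modifying Catalogs only when you need deeper customizations of the decision tree that determines which Model and Distribution RLlib creates by default.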




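This dict (or an overriding subset of it) is part of AlgorithmConfig and therefore also part of any algorithm-specific config. To change the behavior of RLlib's default models, override it and pass it to an AlgorithmConfig.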

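from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import ModelConfigDict

MODEL_DEFAULTS: ModelConfigDict = {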
    # Experimental flag.
    # If True, user specified no preprocessor to be created
    # (via config._disable_preprocessor_api=True). If True, observations
    # will arrive in model as they are returned by the env.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Number of hidden layers to be used.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "fcnet_activation": "tanh",
    # Initializer function or class descriptor for encoder weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "fcnet_weights_initializer": None,
    # Initializer configuration for encoder weights.
    # This configuration is passed to the initializer defined in
    # `fcnet_weights_initializer`.
    "fcnet_weights_initializer_config": None,
    # Initializer function or class descriptor for encoder bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "fcnet_bias_initializer": None,
    # Initializer configuration for encoder bias.
    # This configuration is passed to the initializer defined in
    # `fcnet_bias_initializer`.
    "fcnet_bias_initializer_config": None,

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter.
    # Example:
    # Use None for making RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "conv_activation": "relu",
    # Initializer function or class descriptor for CNN encoder kernel.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_kernel_initializer": None,
    # Initializer configuration for CNN encoder kernel.
    # This configuration is passed to the initializer defined in
    # `conv_kernel_initializer`.
    "conv_kernel_initializer_config": None,
    # Initializer function or class descriptor for CNN encoder bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_bias_initializer": None,
    # Initializer configuration for CNN encoder bias.
    # This configuration is passed to the initializer defined in
    # `conv_bias_initializer`.
    "conv_bias_initializer_config": None,
    # Initializer function or class descriptor for CNN head (pi, Q, V) kernel.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_transpose_kernel_initializer": None,
    # Initializer configuration for CNN head (pi, Q, V) kernel.
    # This configuration is passed to the initializer defined in
    # `conv_transpose_kernel_initializer`.
    "conv_transpose_kernel_initializer_config": None,
    # Initializer function or class descriptor for CNN head (pi, Q, V) bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_transpose_bias_initializer": None,
    # Initializer configuration for CNN head (pi, Q, V) bias.
    # This configuration is passed to the initializer defined in
    # `conv_transpose_bias_initializer`.
    "conv_transpose_bias_initializer_config": None,

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
    #   everything is concated and pushed through this final FC stack.
    # - VisionNets (CNNs), e.g. after the CNN stack, there may be
    #   additional Dense layers.
    # - FullyConnectedNetworks will have this additional FCStack as well
    #   (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",
    # Initializer function or class descriptor for head (pi, Q, V) weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "post_fcnet_weights_initializer": None,
    # Initializer configuration for head (pi, Q, V) weights.
    # This configuration is passed to the initializer defined in
    # `post_fcnet_weights_initializer`.
    "post_fcnet_weights_initializer_config": None,
    # Initializer function or class descriptor for head (pi, Q, V) bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "post_fcnet_bias_initializer": None,
    # Initializer configuration for head (pi, Q, V) bias.
    # This configuration is passed to the initializer defined in
    # `post_fcnet_bias_initializer`.
    "post_fcnet_bias_initializer_config": None,

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect if using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Initializer function or class descriptor for LSTM weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "lstm_weights_initializer": None,
    # Initializer configuration for LSTM weights.
    # This configuration is passed to the initializer defined in
    # `lstm_weights_initializer`.
    "lstm_weights_initializer_config": None,
    # Initializer function or class descriptor for LSTM bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "lstm_bias_initializer": None,
    # Initializer configuration for LSTM bias.
    # This configuration is passed to the initializer defined in
    # `lstm_bias_initializer`.
    "lstm_bias_initializer_config": None,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gated Transformer-XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # Whether to feed r_{t-n:t-1} to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # === Options for ModelConfigs in RLModules ===
    # The latent dimension to encode into.
    # Since most RLModules have an encoder and heads, this establishes an agreement
    # on the dimensionality of the latent space they share.
    # This has no effect for models outside RLModule.
    # If None, model_config['fcnet_hiddens'][-1] value will be used to guarantee
    # backward compatibility to old configs. This yields different models than past
    # versions of RLlib.
    "encoder_latent_dim": None,
    # Whether to always check the inputs and outputs of RLlib's default models for
    # their specifications. Input specifications are checked on failed forward passes
    # of the models regardless of this flag. If this flag is set to `True`, inputs and
    # outputs are checked on every call. This leads to a slow-down and should only be
    # used for debugging. Note that this flag is only relevant for instances of
    # RLlib's Model class. These are commonly generated from ModelConfigs in RLModules.
    "always_check_shapes": False,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
    # Deprecated in anticipation of RLModules API
    "_use_default_native_models": DEPRECATED_VALUE,
}
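
To make "an overriding subset" concrete, the following minimal sketch merges a small override dict onto a set of defaults. It doesn't import RLlib: `ABRIDGED_DEFAULTS` and `resolve_model_config` are hypothetical names for this illustration, the dict copies only a few entries from the listing above, and rejecting unknown keys is a choice of this sketch rather than documented RLlib behavior.

```python
# Abridged, hypothetical stand-in for a few of the MODEL_DEFAULTS entries above.
ABRIDGED_DEFAULTS = {
    "fcnet_hiddens": [256, 256],
    "fcnet_activation": "tanh",
    "use_lstm": False,
    "max_seq_len": 20,
}


def resolve_model_config(overrides: dict) -> dict:
    """Return the defaults with the given subset of keys overridden."""
    unknown = set(overrides) - set(ABRIDGED_DEFAULTS)
    if unknown:
        # Choice made for this sketch: reject misspelled keys early.
        raise KeyError(f"Unknown model config keys: {sorted(unknown)}")
    # Later entries win, so overrides replace defaults key by key.
    return {**ABRIDGED_DEFAULTS, **overrides}


# Only two keys are overridden; all other keys keep their default values.
model_config = resolve_model_config({"fcnet_hiddens": [64, 64], "use_lstm": True})
```

Keys that the override dict doesn't mention, such as `fcnet_activation` here, keep their default values.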


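While Catalogs have a base class Catalog, you mostly interact with algorithm-specific Catalogs. Therefore, this guide also includes examples around PPO from which you can extrapolate to other algorithms. A prerequisite for this user guide is a rough understanding of RLModules. This user guide covers the following topics:

  • What are Catalogs

  • Catalog design and ideas

  • Catalog and AlgorithmConfig

  • Basic usage

  • Inject your custom models into RLModules

  • Inject your custom action distributions into RLModules

  • Write a Catalog from scratch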


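Catalogs have two primary roles: choosing the right Model and choosing the right Distribution. By default, all Catalogs implement decision trees that decide on the model architecture based on a combination of input configurations. These mainly include the observation space and action space of the RLModule, the model config dict, and the deep learning framework backend.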
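
As a rough illustration of such a decision tree, the sketch below picks an encoder family and a distribution family from toy inputs. Everything in it is hypothetical (`ToySpace`, `choose_encoder_kind`, `choose_distribution_name`), and the rules only loosely follow the comments in the model defaults above (fully connected net for 1D inputs, vision net for image-like inputs, optional LSTM wrapping). The actual decision tree in RLlib covers more cases and also takes the framework backend into account.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToySpace:
    """Hypothetical, minimal stand-in for a gymnasium space."""

    shape: Tuple[int, ...]
    n: Optional[int] = None  # Number of discrete choices, None if continuous.


def choose_encoder_kind(observation_space: ToySpace, model_config: dict) -> str:
    """Pick an encoder family from the observation space and the model config."""
    if model_config.get("use_lstm", False):
        return "recurrent"
    if len(observation_space.shape) == 1:
        return "mlp"
    if len(observation_space.shape) == 3:
        return "cnn"
    raise ValueError(f"No default encoder for shape {observation_space.shape}")


def choose_distribution_name(action_space: ToySpace) -> str:
    """Pick a distribution family from the action space."""
    return "categorical" if action_space.n is not None else "diag_gaussian"


cartpole_obs = ToySpace(shape=(4,))
cartpole_act = ToySpace(shape=(), n=2)
image_obs = ToySpace(shape=(84, 84, 4))
continuous_act = ToySpace(shape=(3,))

decisions = [
    (choose_encoder_kind(cartpole_obs, {}), choose_distribution_name(cartpole_act)),
    (choose_encoder_kind(image_obs, {}), choose_distribution_name(continuous_act)),
    (
        choose_encoder_kind(cartpole_obs, {"use_lstm": True}),
        choose_distribution_name(cartpole_act),
    ),
]
```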

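The following diagram shows the breakdown of the information flow towards the models and distributions within an RLModule. An RLModule creates an instance of the Catalog class it receives in its constructor. It then creates its internal models and distributions with the help of this Catalog.


You can also modify the models or distributions in an RLModule directly by overriding the RLModule's constructor!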



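Catalog example in a PPORLModule

The PPOCatalog receives an observation space, an action space, a model config dict, and the view requirements of an RLModule. The model config dict and the view requirements are only of interest in special cases, such as recurrent networks or attention networks. A PPORLModule has four components that the PPOCatalog creates: Encoder, value function head, policy head, and action distribution.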
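
To show how these four components relate, here is a framework-free sketch of the data flow, assuming a 4-dimensional observation and two discrete actions. All functions and weights are hypothetical and hand-picked for the illustration; they are not what PPOCatalog builds.

```python
import math
import random


def linear(weights, inputs):
    """Apply a bias-free linear map given as a list of weight rows."""
    return [sum(w * x for w, x in zip(row, inputs)) for row in weights]


def encoder(obs):
    # Hypothetical encoder: maps a 4-dim observation to a 2-dim latent vector.
    rows = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
    return [math.tanh(v) for v in linear(rows, obs)]


def policy_head(latent):
    # Produces one logit per discrete action (two actions here).
    return linear([[1.0, -1.0], [-1.0, 1.0]], latent)


def value_head(latent):
    # Produces a single scalar state-value estimate.
    return linear([[0.3, 0.7]], latent)[0]


def sample_categorical(logits, rng):
    """Action distribution: softmax over the logits, then draw one action index."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, probs


obs = [0.1, -0.2, 0.05, 0.3]
latent = encoder(obs)  # Shared input of both heads.
logits = policy_head(latent)
value = value_head(latent)
action, probs = sample_categorical(logits, random.Random(0))
```

Both heads consume the same latent vector, which is why the encoder's output size has to be agreed upon before the heads are built.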





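RL algorithms need neural network models and distributions. Within an algorithm, many different architectures for these sub-components are valid. Moreover, models and distributions vary with the environment. However, most algorithms require models that have similarities. The problem is to find sensible sub-components for a wide range of use cases while sharing this functionality across algorithms.


As mentioned above, Catalogs implement decision trees for the sub-components of RLModules. Models and distributions from a Catalog object are meant to fit together. Since we mostly build RLModules out of Encoders, Heads, and Distributions, Catalogs generally reflect this, too. For example, the PPOCatalog outputs an Encoder that produces a latent vector and two Heads that take this latent vector as input. (That's why Catalogs have a latent_dims attribute.) Heads and distributions behave accordingly. Whenever you create a Catalog, the decision tree executes to find suitable configs for models and suitable classes for distributions. By default, this happens in _get_encoder_config() and _get_dist_cls_from_action_space(). Whenever you build a model, the config is turned into a model. Distributions are instantiated on each forward pass of an RLModule and are therefore not built.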

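API philosophy#

Catalogs attempt to encapsulate most of the complexity around models inside the Encoder. This means that recurrency, attention, and other special cases are fully handled inside the Encoder and are transparent to other components. Encoders are the only components that the Catalog base class builds. This is because many algorithms require custom heads and distributions, but most of them can use the same encoders. The Catalog API is designed such that interaction usually happens in two stages:

  • Instantiate a Catalog. This executes the decision tree.

  • Generate an arbitrary number of the decided components through Catalog methods.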
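
The two stages can be sketched without RLlib as follows. `ToyCatalog`, `ToyEncoderConfig`, `ToyEncoder`, and `ToyCategorical` are hypothetical stand-ins; the point is only that construction stores configs and classes, while each build call creates a fresh model, and distributions are handed out as a class rather than as an instance.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyEncoderConfig:
    """Hypothetical config object: describes a model without creating it."""

    input_dim: int
    hidden_dims: List[int]

    def build(self) -> "ToyEncoder":
        return ToyEncoder(self)


class ToyEncoder:
    def __init__(self, config: ToyEncoderConfig):
        self.layer_dims = [config.input_dim, *config.hidden_dims]


class ToyCategorical:
    def __init__(self, logits: List[float]):
        self.logits = logits


class ToyCatalog:
    def __init__(self, input_dim: int, model_config: dict):
        # Stage 1: the decision tree runs once, at construction time, and only
        # produces a config and a distribution class.
        self.encoder_config = ToyEncoderConfig(
            input_dim, model_config.get("fcnet_hiddens", [256, 256])
        )
        self.latent_dims = (self.encoder_config.hidden_dims[-1],)
        self._action_dist_cls = ToyCategorical

    def build_encoder(self) -> ToyEncoder:
        # Stage 2: every call turns the stored config into a new model.
        return self.encoder_config.build()

    def get_action_dist_cls(self):
        # Distributions are created per forward pass, so only the class is returned.
        return self._action_dist_cls


catalog = ToyCatalog(input_dim=4, model_config={"fcnet_hiddens": [64, 32]})
encoder_a = catalog.build_encoder()
encoder_b = catalog.build_encoder()
dist = catalog.get_action_dist_cls()(logits=[0.0, 0.0])
```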


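You can override the component-building methods, such as build_encoder() and get_action_dist_cls() on the base class, to quickly change which models RLModules build. Other methods are private, and you should only override them to make deep changes to the decision tree that enhance the capabilities of Catalogs. Additionally, get_tokenizer_config() is a method that you can use when tokenization is required. Tokenization means single-step embedding. Encoding also means embedding, but it can span multiple timesteps. In fact, the tokenizers that RLlib uses in its recurrent Encoders (for example, TorchLSTMEncoder) are instances of non-recurrent Encoder classes.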
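
The distinction between tokenizing and encoding can be illustrated with a toy example. The `tokenizer` and `recurrent_encoder` functions below are hypothetical, and the decayed running sum is only a stand-in for a recurrent cell; it is not what TorchLSTMEncoder computes.

```python
from typing import List, Sequence


def tokenizer(obs: Sequence[float]) -> List[float]:
    """Single-step embedding: depends on one observation only."""
    return [2.0 * x for x in obs]


def recurrent_encoder(
    obs_sequence: Sequence[Sequence[float]], decay: float = 0.5
) -> List[List[float]]:
    """Multi-step embedding: tokenizes each step, then mixes the token into a
    state that is carried across timesteps (a decayed running sum here)."""
    state = [0.0] * len(obs_sequence[0])
    outputs = []
    for obs in obs_sequence:
        token = tokenizer(obs)
        state = [decay * s + t for s, t in zip(state, token)]
        outputs.append(state)
    return outputs


sequence = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
tokens = [tokenizer(obs) for obs in sequence]
encodings = recurrent_encoder(sequence)
```

The first two observations are identical, so their tokens are identical, while their encodings differ because the encoder also depends on what came before.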


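Since Catalogs effectively control which models and distributions RLlib uses under the hood, they are also part of RLlib's configuration. As the primary entry point for configuring RLlib, AlgorithmConfig is the place where you configure the Catalogs of the RLModules that RLlib creates. You set the catalog class through the RLModuleSpec or MultiRLModuleSpec of an AlgorithmConfig. For example, in heterogeneous multi-agent cases, you modify the MultiRLModuleSpec.


The following example shows how to configure the Catalog of an RLModule created by PPO.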

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)

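config = (
    PPOConfig()
    # Assumption: the exact flags for enabling the new API stack depend on the
    # Ray version in use.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .framework("torch")
)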

# Specify the catalog to use for the PPORLModule.
config = config.rl_module(rl_module_spec=RLModuleSpec(catalog_class=MyPPOCatalog))
# This is how RLlib constructs a PPORLModule
# It will say "Hi from within PPORLModule!".
ppo = config.build()



High-level API#


import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")


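The second example showcases how to use the base Catalog to create a model and an action distribution. In addition, we create a head network by hand that fits these two components together.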

import gymnasium as gym
import torch

# ENCODER_OUT is a constant we use to enumerate Encoder I/O.
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Build a suitable head model for the action distribution.
# We need `env.action_space.n` action distribution inputs.
head = torch.nn.Linear(catalog.latent_dims[0], env.action_space.n)
# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {Columns.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()


The third example showcases how to use the PPOCatalog to create an encoder and an action distribution. This is more similar to what RLlib does internally.

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

# ENCODER_OUT and ACTOR are constants we use to enumerate Encoder I/O.
from ray.rllib.core.models.base import ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()

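Note that the above two examples illustrate, in principle, what it takes to implement a Catalog. In this case, they show the difference between Catalog and PPOCatalog. In most cases, you can reuse the capabilities of the base class Catalog and only need to add methods that build head networks, which you can then use in the appropriate RLModule.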
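
The relationship between a base catalog and an algorithm-specific one can be sketched as plain inheritance. `ToyBaseCatalog` and `ToyActorCriticCatalog` are hypothetical classes, and dicts of layer sizes stand in for real networks; the sketch only shows that the subclass reuses the encoder logic and adds head builders that agree with it on the latent size.

```python
class ToyBaseCatalog:
    """Hypothetical base catalog: it only knows how to build an encoder."""

    def __init__(self, obs_dim: int, num_actions: int, latent_dim: int = 256):
        self.obs_dim = obs_dim
        self.num_actions = num_actions
        self.latent_dims = (latent_dim,)

    def build_encoder(self) -> dict:
        # A dict of layer sizes stands in for a real network.
        return {"in": self.obs_dim, "out": self.latent_dims[0]}


class ToyActorCriticCatalog(ToyBaseCatalog):
    """Algorithm-specific catalog: reuses the encoder, adds two head builders."""

    def build_pi_head(self) -> dict:
        # One output per discrete action.
        return {"in": self.latent_dims[0], "out": self.num_actions}

    def build_vf_head(self) -> dict:
        # A single scalar value estimate.
        return {"in": self.latent_dims[0], "out": 1}


catalog = ToyActorCriticCatalog(obs_dim=4, num_actions=2, latent_dim=64)
encoder = catalog.build_encoder()
pi_head = catalog.build_pi_head()
vf_head = catalog.build_vf_head()
```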


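You can make a Catalog build custom models by overriding the Catalog methods that RLModules use to build models. Have a look at these lines from the constructor of the PPOTorchRLModule to see how an RLModule uses Catalogs.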

        catalog = self.config.get_catalog()
        # If we have a stateful model, states for the critic need to be collected
        # during sampling and `inference-only` needs to be `False`. Note, at this
        # point the encoder is not built, yet and therefore `is_stateful()` does
        # not work.
        is_stateful = isinstance(
            catalog.actor_critic_encoder_config.base_encoder_config,
            RecurrentEncoderConfig,
        )
        if is_stateful:
            self.config.inference_only = False
        # If this is an `inference_only` Module, we'll have to pass this information
        # to the encoder config as well.
        if self.config.inference_only and self.framework == "torch":
            catalog.actor_critic_encoder_config.inference_only = True

        # Build models from catalog.
        self.encoder = catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = catalog.build_pi_head(framework=self.framework)
        self.vf = catalog.build_vf_head(framework=self.framework)

        self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)

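Note that what happens inside the constructor of PPOTorchRLModule is similar to the earlier example that creates models and distributions for PPO.

Consequently, in order to build a custom Model that is compatible with a PPORLModule, you can override methods by inheriting from PPOCatalog, or write a Catalog that implements these methods from scratch. The following example showcases such modifications:


  • How to write a custom Encoder

  • How to inject your custom encoder into a Catalog

Note that, if you only want to inject your Encoder into a single RLModule, the recommended workflow is to inherit from an existing RLModule and place the Encoder there.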

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
    MobileNetV2EncoderConfig,
    MOBILENET_INPUT_SHAPE,
)
from ray.rllib.examples.envs.classes.random_env import RandomEnv


# Define a PPO Catalog that we can use to inject our MobileNetV2 Encoder into RLlib's
# decision tree of what model to choose
class MobileNetEnhancedPPOCatalog(PPOCatalog):
    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if (
            isinstance(observation_space, gym.spaces.Box)
            and observation_space.shape == MOBILENET_INPUT_SHAPE
        ):
            # Inject our custom encoder here, only if the observation space fits it
            return MobileNetV2EncoderConfig()
        else:
            return super()._get_encoder_config(observation_space, **kwargs)


# Create a generic config with our enhanced Catalog
ppo_config = (
    PPOConfig()
    # Assumption: the exact flags for enabling the new API stack depend on the
    # Ray version in use.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .rl_module(rl_module_spec=RLModuleSpec(catalog_class=MobileNetEnhancedPPOCatalog))
    # The following training settings make it so that a training iteration is very
    # quick. This is just for the sake of this example. PPO will not learn properly
    # with these settings!
    .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1)
)

# CartPole's observation space is not compatible with our MobileNetV2 Encoder, so
# this will use the default behaviour of Catalogs
ppo_config.environment("CartPole-v1")
results = ppo_config.build().train()

# For this training, we use a RandomEnv with observations of shape
# MOBILENET_INPUT_SHAPE. This will use our custom Encoder.
ppo_config.environment(
    RandomEnv,
    env_config={
        "action_space": gym.spaces.Discrete(2),
        # Test a simple Image observation space.
        "observation_space": gym.spaces.Box(
            0.0,
            1.0,
            shape=MOBILENET_INPUT_SHAPE,
            dtype=np.float32,
        ),
    },
)
results = ppo_config.build().train()


import torch
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.models.distributions import Distribution
from ray.rllib.models.torch.torch_distributions import TorchDeterministic

# Define a simple categorical distribution that can be used for PPO
class CustomTorchCategorical(Distribution):
    def __init__(self, logits):
        self.torch_dist = torch.distributions.categorical.Categorical(logits=logits)

    def sample(self, sample_shape=torch.Size(), **kwargs):
        return self.torch_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size(), **kwargs):
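        # Note: torch's Categorical does not support reparameterized sampling,
        # so this call raises NotImplementedError if it is ever used.
        return self.torch_dist.rsample(sample_shape)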

    def logp(self, value, **kwargs):
        return self.torch_dist.log_prob(value)

    def entropy(self):
        return self.torch_dist.entropy()

    def kl(self, other, **kwargs):
        return torch.distributions.kl.kl_divergence(self.torch_dist, other.torch_dist)

    @staticmethod
    def required_input_dim(space, **kwargs):
        return int(space.n)

    # This method is used to create distributions from logits inside RLModules.
    # You can use this to inject arguments into the constructor of this distribution
    # that are not the logits themselves.
    @classmethod
    def from_logits(cls, logits):
        return CustomTorchCategorical(logits=logits)

    # This method is used to create a deterministic distribution for the
    # PPORLModule.forward_inference.
    def to_deterministic(self):
        return TorchDeterministic(loc=torch.argmax(self.torch_dist.logits, dim=-1))

# See if we can create this distribution and sample from it to interact with our
# target environment
env = gym.make("CartPole-v1")
dummy_logits = torch.randn([env.action_space.n])
dummy_dist = CustomTorchCategorical.from_logits(dummy_logits)
action = dummy_dist.sample()
obs, info = env.reset()
env.step(action.item())

# Define a simple catalog that returns our custom distribution when
# get_action_dist_cls is called
class CustomPPOCatalog(PPOCatalog):
    def get_action_dist_cls(self, framework):
        # The distribution we wrote will only work with torch
        assert framework == "torch"
        return CustomTorchCategorical

# Train with our custom action distribution
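algo = (
    PPOConfig()
    # Assumption: the exact flags for enabling the new API stack depend on the
    # Ray version in use.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .rl_module(rl_module_spec=RLModuleSpec(catalog_class=CustomPPOCatalog))
    .build()
)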
results = algo.train()
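
To see what the methods of such a categorical distribution compute, the following framework-free sketch writes the same quantities out with plain Python math. The function names are hypothetical and chosen to mirror the methods of the custom distribution above; the code doesn't depend on RLlib or torch.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-probabilities of a categorical distribution."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return [v - log_norm for v in logits]


def logp(logits: List[float], action: int) -> float:
    """Log-probability of one action index."""
    return log_softmax(logits)[action]


def entropy(logits: List[float]) -> float:
    """Entropy: minus the sum over actions of p * log p."""
    log_probs = log_softmax(logits)
    return -sum(math.exp(lp) * lp for lp in log_probs)


def kl(logits_p: List[float], logits_q: List[float]) -> float:
    """KL divergence KL(p || q) between two categorical distributions."""
    log_p, log_q = log_softmax(logits_p), log_softmax(logits_q)
    return sum(math.exp(a) * (a - b) for a, b in zip(log_p, log_q))


def deterministic_action(logits: List[float]) -> int:
    """Greedy action: the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


uniform_entropy = entropy([0.0, 0.0])
greedy = deterministic_action([0.1, 2.0, -1.0])
```

For two equal logits, the distribution is uniform over two actions, so the entropy equals the natural logarithm of 2.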

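These examples target PPO, but the workflows apply to all RLlib algorithms. Note that PPO adds the ActorCriticEncoder (from ray.rllib.core.models.base) and two heads (policy head and value head) to the base class. You can override these in the same way. Other algorithms may add different sub-components or override default ones.


You only need to write a Catalog from scratch if you want to write a new algorithm under RLlib. Note that writing an algorithm doesn't strictly require writing a new Catalog, but you can use Catalogs as a tool to create fitting default sub-components, such as models or distributions. The PPOCatalog implementation below illustrates the typical requirements and steps for writing a new Catalog:


Catalog for PPORLModules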
import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import (
    ActorCriticEncoderConfig,
    MLPHeadConfig,
    FreeLogStdMLPHeadConfig,
)
from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model
from ray.rllib.utils import override
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic

def _check_if_diag_gaussian(action_distribution_cls, framework):
    if framework == "torch":
        from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian

        assert issubclass(action_distribution_cls, TorchDiagGaussian), (
            "free_log_std is only supported for DiagGaussian action distributions. "
            f"Found action distribution: {action_distribution_cls}."
        )
    elif framework == "tf2":
        from ray.rllib.models.tf.tf_distributions import TfDiagGaussian

        assert issubclass(action_distribution_cls, TfDiagGaussian), (
            "free_log_std is only supported for DiagGaussian action distributions. "
            "Found action distribution: {}.".format(action_distribution_cls)
        )
    else:
        raise ValueError(f"Framework {framework} not supported for free_log_std.")

class PPOCatalog(Catalog):
    """The Catalog class used to build models for PPO.

    PPOCatalog provides the following models:
        - ActorCriticEncoder: The encoder used to encode the observations.
        - Pi Head: The head used to compute the policy logits.
        - Value Function Head: The head used to compute the value function.

    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
    for the policy and value function. See implementations of PPORLModule for
    more details.

    Any custom ActorCriticEncoder can be built by overriding the
    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
    at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
    ActorCriticEncoder during RLModule runtime.

    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
    methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
    build custom heads during RLModule runtime.

    Any module built for exploration or inference is built with the flag
    `inference_only=True` and does not contain a value network. This flag can be set
    in the `SingleAgentModuleSpec` through the `inference_only` boolean flag.
    In case that the actor-critic-encoder is not shared between the policy and value
    function, the inference-only module will contain only the actor encoder network.
    """

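    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes the PPOCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Pi Head.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        # Replace EncoderConfig by ActorCriticEncoderConfig
        self.actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=self._encoder_config,
            shared=self._model_config_dict["vf_share_layers"],
        )

        self.pi_and_vf_head_hiddens = self._model_config_dict["post_fcnet_hiddens"]
        self.pi_and_vf_head_activation = self._model_config_dict[
            "post_fcnet_activation"
        ]

        # We don't have the exact (framework specific) action dist class yet and thus
        # cannot determine the exact number of output nodes (action space) required.
        # -> Build pi config only in the `self.build_pi_head` method.
        self.pi_head_config = None

        self.vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_activation="linear",
            output_layer_dim=1,
        )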

    @OverrideToImplementCustomLogic
    def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder:
        """Builds the ActorCriticEncoder.

        The default behavior is to build the encoder from the encoder_config.
        This can be overridden to build a custom ActorCriticEncoder as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The ActorCriticEncoder.
        """
        return self.actor_critic_encoder_config.build(framework=framework)

    @override(Catalog)
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        Since PPO uses an ActorCriticEncoder, this method should not be implemented.
        """
        raise NotImplementedError(
            "Use PPOCatalog.build_actor_critic_encoder() instead for PPO."
        )

    @OverrideToImplementCustomLogic
    def build_pi_head(self, framework: str) -> Model:
        """Builds the policy head.

        The default behavior is to build the head from the pi_head_config.
        This can be overridden to build a custom policy head as a means of configuring
        the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The policy head.
        """
        # Get action_distribution_cls to find out about the output dimension for pi_head
        action_distribution_cls = self.get_action_dist_cls(framework=framework)
        if self._model_config_dict["free_log_std"]:
            _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls, framework=framework
            )
        required_output_dim = action_distribution_cls.required_input_dim(
            space=self.action_space, model_config=self._model_config_dict
        )
        # Now that we have the action dist class and number of outputs, we can define
        # our pi-config and build the pi head.
        pi_head_config_class = (
            FreeLogStdMLPHeadConfig
            if self._model_config_dict["free_log_std"]
            else MLPHeadConfig
        )
        self.pi_head_config = pi_head_config_class(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_dim=required_output_dim,
            output_layer_activation="linear",
        )

        return self.pi_head_config.build(framework=framework)

    @OverrideToImplementCustomLogic
    def build_vf_head(self, framework: str) -> Model:
        """Builds the value function head.

        The default behavior is to build the head from the vf_head_config.
        This can be overridden to build a custom value function head as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The value function head.
        """
        return self.vf_head_config.build(framework=framework)
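
The `free_log_std` option that build_pi_head() checks can be pictured with the following minimal sketch, assuming a diagonal Gaussian whose inputs are the means followed by the log standard deviations. `ToyFreeLogStdHead` is a hypothetical class written for this illustration and is not the head that RLlib builds; it only shows that the second half of the output doesn't depend on the input.

```python
import math
from typing import List


class ToyFreeLogStdHead:
    """Hypothetical head: the means depend on the input, the log-stds do not."""

    def __init__(self, mean_weights: List[List[float]], log_std: List[float]):
        assert len(mean_weights) == len(log_std)
        self.mean_weights = mean_weights  # One row per action dimension.
        self.log_std = log_std  # Free parameters, one per action dimension.

    def __call__(self, latent: List[float]) -> List[float]:
        means = [sum(w * x for w, x in zip(row, latent)) for row in self.mean_weights]
        # First half: state-dependent means. Second half: state-independent log-stds.
        return means + list(self.log_std)


head = ToyFreeLogStdHead(mean_weights=[[1.0, 0.0], [0.0, -1.0]], log_std=[-0.5, 0.0])
out_a = head([0.2, 0.4])
out_b = head([-1.0, 3.0])
# The standard deviations are recovered by exponentiating the second half.
stds = [math.exp(v) for v in out_a[2:]]
```

In a trained model, the free log-stds would be learnable parameters updated by the optimizer; here they are fixed numbers to keep the sketch short.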