备注

Ray 2.10.0 引入了 RLlib 的“新 API 栈”的 alpha 阶段。Ray 团队计划将算法、示例脚本和文档迁移到新的代码库中,从而在 Ray 3.0 之前的后续小版本中逐步替换“旧 API 栈”(例如,ModelV2、Policy、RolloutWorker)。

然而,请注意,到目前为止,只有 PPO(单代理和多代理)和 SAC(仅单代理)支持“新 API 堆栈”,并且默认情况下继续使用旧 API 运行。您可以继续使用现有的自定义(旧堆栈)类。

请参阅此处 以获取有关如何使用新API堆栈的更多详细信息。

模型API#

基础模型类#

ModelV2

定义一个用于 RLlib 的抽象神经网络模型。

TorchModelV2

ModelV2 的 Torch 版本。

TFModelV2

ModelV2 的 TF 版本,应包含一个 tf keras 模型。

前馈方法#

forward

使用给定的输入张量和状态调用模型。

value_function

返回最近一次前向传递的值函数输出。

last_output

返回从调用模型返回的最后一个输出。

循环模型 API#

get_initial_state

获取模型的初始递归状态值。

is_time_major

如果为 True,调用此 ModelV2 的数据必须为时间优先格式。

访问变量#

variables

返回此模型的变量列表(或字典)。

trainable_variables

返回此模型的可训练变量列表。

Distribution

分布在随机变量上的基类。

自定义#

custom_loss

覆盖以自定义用于优化此模型的损失函数。

metrics

覆盖以从您的模型返回自定义指标。