ray.rllib.models.modelv2.ModelV2#

class ray.rllib.models.modelv2.ModelV2(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)[源代码]#

基类:object

定义一个用于 RLlib 的抽象神经网络模型。

自定义模型应扩展 TFModelV2 或 TorchModelV2,而不是直接扩展此类。

数据流:
obs -> forward() -> model_out

-> value_function() -> V(s)

方法

__init__

初始化一个 ModelV2 实例。

context

返回当前前向传递的上下文管理器。

custom_loss

覆盖以自定义用于优化此模型的损失函数。

forward

使用给定的输入张量和状态调用模型。

get_initial_state

获取模型的初始递归状态值。

is_time_major

如果为 True,调用此 ModelV2 的数据必须为时间优先格式。

last_output

返回从调用模型返回的最后一个输出。

metrics

覆盖以从您的模型返回自定义指标。

trainable_variables

返回此模型的可训练变量列表。

value_function

返回最近一次前向传递的值函数输出。

variables

返回此模型的变量列表(或字典)。