ray.rllib.policy.torch_policy_v2.TorchPolicyV2#

class ray.rllib.policy.torch_policy_v2.TorchPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, *, max_seq_len: int = 20)[源代码]#

基类:Policy

与 RLlib 一起使用的 PyTorch 特定策略类。

方法

__init__

初始化一个 TorchPolicy 实例。

action_distribution_fn

此策略的动作分布函数。

action_sampler_fn

给定策略,用于采样新动作的自定义函数。

apply

使用此 Policy 实例调用给定的函数。

compute_single_action

计算并返回一个单一(B=1)的动作值。

export_checkpoint

将导出策略检查点保存到本地目录并返回一个 AIR 检查点。

export_model

将策略的模型导出到本地目录以供服务。

extra_action_out

返回包含在经验批次中的额外信息的字典。

extra_compute_grad_fetches

从 compute_gradients() 中获取并返回的额外值。

extra_grad_process

在每次 optimizer.zero_grad() + loss.backward() 调用后被调用。

from_checkpoint

从给定的策略或算法检查点创建新的策略实例。

from_state

从状态对象中恢复策略。

get_batch_divisibility_req

获取批处理可分性请求。

get_connector_metrics

从连接器获取时间指标。

get_exploration_state

返回此策略的探索组件的状态。

get_host

返回计算机的网络名称。

get_session

返回用于计算动作的 tf.Session 对象,或返回 None。

get_tower_stats

返回每个塔的统计列表,复制到此策略的设备中。

import_model_from_h5

将权重导入到 torch 模型中。

init_view_requirements

learn_on_batch()compute_actions 调用的最大视图需求字典。

learn_on_batch_from_replay_buffer

从给定的回放缓存中采样一批数据并执行更新。

loss

构建损失函数。

make_model

创建模型。

make_model_and_action_dist

创建模型和动作分布函数。

make_rl_module

返回 RL 模块(仅在启用 RLModule API 时)。

maybe_add_time_dimension

为循环 RLModules 添加时间维度。

on_global_var_update

在全局变量更新时调用。

optimizer

自定义要使用的本地 PyTorch 优化器。

postprocess_trajectory

对轨迹进行后处理并返回处理后的轨迹。

reset_connectors

重置此策略的操作连接器和代理连接器。

restore_connectors

如果配置可用,则恢复代理和操作连接器。

stats_fn

统计函数。