ray.rllib.policy.policy.Policy#

class ray.rllib.policy.policy.Policy(observation_space: gymnasium.Space, action_space: gymnasium.Space, config: dict)[源代码]#

基类:object

RLlib 所有策略实现的基类。

Policy 是所有特定于 DL-框架子类的抽象超类(例如 TFPolicy 或 TorchPolicy)。它公开了 API 以

  1. 从观察(可能还有其他)输入中计算动作。

  2. 管理策略的 NN 模型,如导出和加载它们的权重。

  3. 通过以下方式对来自环境或其他输入的给定轨迹进行后处理:

    postprocess_trajectory 方法。

  4. 从训练批次中计算损失。

  5. 对NN模型执行来自训练批次的更新(这通常包括损失)

    计算) 要么:

    1. 在一个单一的步骤中(learn_on_batch

    2. 通过批量预加载,然后进行 n 步的实际损失计算和更新

      (load_batch_into_buffer + learn_on_loaded_batch).

方法

__init__

初始化一个策略实例。

apply

使用此 Policy 实例调用给定的函数。

apply_gradients

应用(之前)计算的梯度。

compute_actions

计算当前策略的动作。

compute_actions_from_input_dict

计算从收集的样本(跨多个代理)中得出的动作。

compute_gradients

在给定一批经验的情况下计算梯度。

compute_log_likelihoods

计算给定动作和观察的对数概率/似然。

compute_single_action

计算并返回一个单一(B=1)的动作值。

export_checkpoint

将导出策略检查点保存到本地目录并返回一个 AIR 检查点。

export_model

将策略的模型导出到本地目录以供服务。

from_checkpoint

从给定的策略或算法检查点创建新的策略实例。

from_state

从状态对象中恢复策略。

get_connector_metrics

从连接器获取时间指标。

get_exploration_state

返回此策略的探索组件的状态。

get_host

返回计算机的网络名称。

get_initial_state

返回当前策略的初始RNN状态。

get_num_samples_loaded_into_buffer

返回给定缓冲区中当前加载的样本数量。

get_session

返回用于计算动作的 tf.Session 对象,或返回 None。

get_state

返回此策略的当前整个状态。

get_weights

返回模型权重。

import_model_from_h5

从本地文件导入策略。

init_view_requirements

learn_on_batch()compute_actions 调用的最大视图需求字典。

is_recurrent

此策略是否持有循环模型。

learn_on_batch

执行一次学习更新,基于 samples

learn_on_batch_from_replay_buffer

从给定的回放缓存中采样一批数据并执行更新。

learn_on_loaded_batch

在已经加载到缓冲区中的数据上运行一次SGD的单步操作。

load_batch_into_buffer

将给定的 SampleBatch 批量加载到设备的内存中。

loss

此策略的损失函数。

make_rl_module

返回 RL 模块(仅在启用 RLModule API 时)。

maybe_add_time_dimension

为循环 RLModules 添加时间维度。

maybe_remove_time_dimension

移除循环 RLModules 的时间维度。

num_state_tensors

策略的RNN模型所需的内部状态数量。

on_global_var_update

在全局变量更新时调用。

postprocess_trajectory

实现特定算法的轨迹后处理。

reset_connectors

重置此策略的操作连接器和代理连接器。

restore_connectors

如果配置可用,则恢复代理和操作连接器。

set_state

state 恢复此策略的当前整个状态。

set_weights

设置此策略模型的权重。