ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2#

class ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, **kwargs)[源代码]#

基类:Policy

基于 TF-eager / TF2 的 TensorFlow 策略。

此类旨在通过子类化来使用和扩展。

方法

action_distribution_fn

此策略的动作分布函数。

action_sampler_fn

给定策略,用于采样新动作的自定义函数。

apply

使用此 Policy 实例调用给定的函数。

apply_gradients_fn

梯度计算函数(从损失张量,使用本地优化器)。

compute_gradients_fn

梯度计算函数(从损失张量,使用本地优化器)。

compute_single_action

计算并返回一个单一(B=1)的动作值。

export_checkpoint

将导出策略检查点保存到本地目录并返回一个 AIR 检查点。

extra_action_out_fn

从 compute_actions() 中获取并返回的额外值。

extra_learn_fetches_fn

在梯度计算后要报告的额外统计数据。

from_checkpoint

从给定的策略或算法检查点创建新的策略实例。

from_state

从状态对象中恢复策略。

get_batch_divisibility_req

获取批处理可分性请求。

get_connector_metrics

从连接器获取时间指标。

get_host

返回计算机的网络名称。

get_num_samples_loaded_into_buffer

返回给定缓冲区中当前加载的样本数量。

get_session

返回用于计算动作的 tf.Session 对象,或返回 None。

grad_stats_fn

梯度统计函数。

import_model_from_h5

从本地文件导入策略。

init_view_requirements

learn_on_batch()compute_actions 调用的最大视图需求字典。

learn_on_batch_from_replay_buffer

从给定的回放缓存中采样一批数据并执行更新。

learn_on_loaded_batch

在已经加载到缓冲区中的数据上运行一次SGD的单步操作。

load_batch_into_buffer

将给定的 SampleBatch 批量加载到设备的内存中。

loss

使用模型、dist_class 和 train_batch 计算此策略的损失。

make_model

构建此策略的基础模型。

make_rl_module

返回 RL 模块(仅在启用 RLModule API 时)。

maybe_add_time_dimension

为循环 RLModules 添加时间维度。

on_global_var_update

在全局变量更新时调用。

optimizer

用于策略优化的TF优化器。

postprocess_trajectory

以 SampleBatch 格式进行轨迹后处理。

reset_connectors

重置此策略的操作连接器和代理连接器。

restore_connectors

如果配置可用,则恢复代理和操作连接器。

stats_fn

统计函数。

variables

返回此策略中所有可保存变量的列表。