ray.rllib.policy.sample_batch.样本批次#

class ray.rllib.policy.sample_batch.SampleBatch(*args, **kwargs)[源代码]#

基类:dict

围绕字典的包装器,键为字符串,值为类似数组的对象。

例如,{“obs”: [1, 2, 3], “reward”: [0, -1, 1]} 是一个包含三个样本的批次,每个样本都有一个 “obs” 和 “reward” 属性。

方法

__init__

构建一个样本批次(与字典构造函数的参数相同)。

agent_steps

返回与 len(self) 相同的结果(此批次中的步数)。

as_multi_agent

返回相应的 MultiAgentBatch

clear

columns

返回指定列中的批处理数据列表。

compress

就地压缩数据缓冲区(按列)。

concat

other 连接到当前对象并返回一个新的 SampleBatch。

copy

创建此 SampleBatch 的深拷贝或浅拷贝并返回。

decompress_if_needed

原地解压缩数据缓冲区(如果不是压缩的,则按列解压缩)。

env_steps

返回与 len(self) 相同的结果(此批次中的步数)。

fromkeys

使用可迭代对象中的键创建一个新字典,并将值设置为指定的值。

get

返回数据中按键指定的一列,或返回默认值。

get_single_step_input_dict

在给定的索引处从 self 创建单个 ts SampleBatch。

is_single_trajectory

如果这个 SampleBatch 只包含一个轨迹,则返回 True。

is_terminated_or_truncated

如果 self 在 idx -1 处被终止或截断,则返回 True。

items

keys

pop

如果未找到键,则返回给定的默认值;否则,引发 KeyError。

popitem

移除并返回一个 (键, 值) 对作为 2-tuple。

right_zero_pad

Right (在末尾添加零) 就地对 SampleBatch 进行零填充。

rows

返回一个数据行的迭代器,即包含列值的字典。

set_get_interceptor

设置一个函数在每次 getitem 时被调用。

set_training

设置此 SampleBatch 的 is_training 标志。

setdefault

如果字典中不存在键,则插入键并赋予默认值。

shuffle

就地打乱此批次中的行。

size_bytes

返回所有数据缓冲区字节数的总和。

slice

返回此批次行数据的切片(不复制)。

split_by_episode

eps_id 列分割并返回新批次列表。

timeslices

返回 SampleBatches,每个代表此数据集的一个 k-切片。

to_device

TODO: 将批处理转移到指定设备作为框架张量。

update

如果 E 存在且有 .keys() 方法,则执行: for k in E: D[k] = E[k] 如果 E 存在但没有 .keys() 方法,则执行: for k, v in E: D[k] = v 在任何一种情况下,之后都会执行: for k in F: D[k] = F[k]

values

属性

ACTIONS

ACTION_DIST

ACTION_DIST_INPUTS

ACTION_LOGP

ACTION_PROB

AGENT_INDEX

ATTENTION_MASKS

CUR_OBS

DONES

ENV_ID

EPS_ID

INFOS

NEXT_OBS

OBS

OBS_EMBEDS

PREV_ACTIONS

PREV_REWARDS

RETURNS_TO_GO

REWARDS

SEQ_LENS

T

TERMINATEDS

TRUNCATEDS

UNROLL_ID

VALUES_BOOTSTRAPPED

VF_PREDS

is_training