ray.rllib.core.learner.learner_group.学习者组#

class ray.rllib.core.learner.learner_group.LearnerGroup(*, config: AlgorithmConfig, module_spec: RLModuleSpec | MultiRLModuleSpec | None = None)[源代码]#

基类:Checkpointable

n(可能是远程的)学习者工作者的协调者。

每个 Learner worker 都有一份 RLModule 的副本、损失函数(一个或多个),以及一个或多个优化器。

PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。

方法

__init__

初始化一个 LearnerGroup 实例。

add_module

将一个模块添加到底层的 MultiRLModule 中。

foreach_learner

对每个学习者 L 调用给定的函数,参数为:(L, **kwargs)。

from_checkpoint

从给定位置创建一个新的 Checkpointable 实例并返回它。

get_metadata

返回可写入的JSON元数据,进一步描述实现类。

get_stats

返回此学习组的输入队列的当前统计信息。

get_weights

便捷方法,替代 self.get_state(components=...)。

remove_module

从学习者中移除一个模块。

restore_from_path

从给定的路径恢复实现类的状态。

save_to_path

将实现类的状态(或 state)保存到 path

set_weights

便捷方法,而不是 self.set_state({'learner': {'rl_module': ..}})。

shutdown

关闭 LearnerGroup。

update_from_batch

基于给定的批次,对学习器执行基于梯度的更新。

update_from_episodes

基于给定的片段,对学习者执行基于梯度的更新。

属性

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

is_local

is_remote