ray.tune.schedulers.pb2.PB2#
- class ray.tune.schedulers.pb2.PB2(time_attr: str = 'time_total_s', metric: str | None = None, mode: str | None = None, perturbation_interval: float = 60.0, hyperparam_bounds: Dict[str, dict | list | tuple] = None, quantile_fraction: float = 0.25, log_config: bool = True, require_attrs: bool = True, synch: bool = False, custom_explore_fn: Callable[[dict], dict] | None = None)[源代码]#
-
实现了基于种群的Bandit(PB2)算法。
PB2 并行训练一组模型(或代理)。定期地,表现不佳的模型会克隆表现最佳者的状态,并使用 GP-bandit 优化重新选择超参数。GP 模型被训练来预测下一个训练周期的改进。
与PBT类似,PB2在训练期间调整超参数。这使得超参数的发现速度非常快,并且还能自动发现调度。
此 Tune PB2 实现基于 Tune 的 PBT 实现构建。它将所有添加的试验视为 PB2 群体的一部分。如果试验数量超过集群容量,它们将被时间复用以平衡群体中的训练进度。要运行多个试验,请使用
tune.TuneConfig(num_samples=<int>)
。在 {LOG_DIR}/{MY_EXPERIMENT_NAME}/ 中,所有变异都记录在
pb2_global.txt
中,而各个策略扰动则记录在 pb2_policy_{i}.txt 中。Tune 日志:每次扰动步骤记录 [目标试验标签, 克隆试验标签, 目标试验迭代, 克隆试验迭代, 旧配置, 新配置]。- 参数:
time_attr – 用于比较时间的训练结果属性。请注意,您可以传入非时间性的内容,例如
training_iteration
作为进度衡量标准,唯一的要求是该属性应单调递增。metric – 训练结果目标值属性。停止程序将使用此属性。
mode – 其中之一 {min, max}。确定目标是最小化还是最大化指标属性。
perturbation_interval – 模型将在这个
time_attr
的时间间隔内考虑进行扰动。请注意,扰动会产生检查点开销,因此不应设置得过于频繁。hyperparam_bounds – 要变异的超参数。格式如下:对于每个键,输入一个形式为 [min, max] 的列表,表示超参数的最小值和最大值。一个键也可以包含一个字典用于嵌套的超参数。如果在试验的初始
config
中不存在相应的超参数,Tune 将在hyperparam_bounds
提供的边界之间均匀采样以获得初始超参数值。quantile_fraction – 参数从顶部的
quantile_fraction
部分试验转移到底部的quantile_fraction
部分。需要在 0 到 0.5 之间。设置为 0 基本上意味着完全不进行利用。custom_explore_fn – 你也可以指定一个自定义的探索函数。这个函数被调用为
f(config)
,其中输入是由贝叶斯优化生成的新配置。这个函数应返回根据需要更新的config
。log_config – 是否在每次利用时将每个模型的 ray 配置记录到 local_dir。允许重建配置调度。
require_attrs – 是否要求 time_attr 和 metric 在每次迭代的 result 中出现。如果为 True,当这些值在 trial result 中不存在时,将会引发错误。
synch – 如果为 False,将使用 PBT 的异步实现。每个试验独立地在每个 perturbation_interval 发生扰动。如果为 True,将使用 PBT 的同步实现。扰动只会在所有试验在每个 perturbation_interval 同步时发生。默认为 False。请参阅这里的附录 A.1 https://arxiv.org/pdf/1711.09846.pdf。
示例
from ray import tune from ray.tune.schedulers.pb2 import PB2 from ray.tune.examples.pbt_function import pbt_function # run "pip install gpy" to use PB2 pb2 = PB2( metric="mean_accuracy", mode="max", perturbation_interval=20, hyperparam_bounds={"lr": [0.0001, 0.1]}, ) tuner = tune.Tuner( pbt_function, tune_config=tune.TuneConfig( scheduler=pb2, num_samples=8, ), param_space={"lr": 0.0001}, ) tuner.fit()
方法
确保所有试验获得公平的时间份额(由 time_attr 定义)。
从检查点恢复试用调度器。
将试验调度器保存到检查点
属性
继续试验执行的状态
暂停试验执行的状态
停止试验执行的状态