ray.tune.Trainable.load_checkpoint#
- Trainable.load_checkpoint(checkpoint: Dict | None)[源代码]#
子类应重写此方法以实现 restore()。
警告
在此方法中,不要依赖绝对路径。在
Trainable.save_checkpoint
中使用的 checkpoint_dir 的绝对路径可能会改变。如果
Trainable.save_checkpoint
返回一个带前缀的字符串,Trainable.save_checkpoint
返回的检查点字符串的前缀可能会改变。这是因为试验暂停依赖于临时目录。提供给
Trainable.save_checkpoint
的 checkpoint_dir 下的目录结构被保留。请参见下面的示例。
示例
>>> import os >>> from ray.tune.trainable import Trainable >>> class Example(Trainable): ... def save_checkpoint(self, checkpoint_path): ... my_checkpoint_path = os.path.join(checkpoint_path, "my/path") ... return my_checkpoint_path ... def load_checkpoint(self, my_checkpoint_path): ... print(my_checkpoint_path) >>> trainer = Example() >>> # This is used when PAUSED. >>> checkpoint_result = trainer.save() >>> trainer.restore(checkpoint_result)
如果
Trainable.save_checkpoint
返回一个字典,那么 Tune 将直接将字典数据作为参数传递给此方法。示例
>>> from ray.tune.trainable import Trainable >>> class Example(Trainable): ... def save_checkpoint(self, checkpoint_path): ... return {"my_data": 1} ... def load_checkpoint(self, checkpoint_dict): ... print(checkpoint_dict["my_data"])
Added in version 0.8.7.
- 参数:
checkpoint – 如果是字典,返回值与
save_checkpoint
返回的值相同。否则,返回存储检查点的目录。