ray.tune.Trainable.load_checkpoint#

Trainable.load_checkpoint(checkpoint: Dict | None)[源代码]#

子类应重写此方法以实现 restore()。

警告

在此方法中,不要依赖绝对路径。在 Trainable.save_checkpoint 中使用的 checkpoint_dir 的绝对路径可能会改变。

如果 Trainable.save_checkpoint 返回一个带前缀的字符串,Trainable.save_checkpoint 返回的检查点字符串的前缀可能会改变。这是因为试验暂停依赖于临时目录。

提供给 Trainable.save_checkpoint 的 checkpoint_dir 下的目录结构被保留。

请参见下面的示例。

示例

>>> import os
>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        my_checkpoint_path = os.path.join(checkpoint_path, "my/path")
...        return my_checkpoint_path
...    def load_checkpoint(self, my_checkpoint_path):
...        print(my_checkpoint_path)
>>> trainer = Example()
>>> # This is used when PAUSED.
>>> checkpoint_result = trainer.save() 
>>> trainer.restore(checkpoint_result) 

如果 Trainable.save_checkpoint 返回一个字典,那么 Tune 将直接将字典数据作为参数传递给此方法。

示例

>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        return {"my_data": 1}
...    def load_checkpoint(self, checkpoint_dict):
...        print(checkpoint_dict["my_data"])

Added in version 0.8.7.

参数:

checkpoint – 如果是字典,返回值与 save_checkpoint 返回的值相同。否则,返回存储检查点的目录。