检查点进度
Metaflow 工件用于在任务完成时持久化模型、数据框和其他 Python 对象。它们在步骤边界检查流的状态,使您能够使用 Client API 检查任务的结果并 resume 从任意步骤 执行。
在某些情况下,一个任务可能需要很长时间才能执行。例如,在昂贵的GPU实例上训练模型(或在集群上)可能需要几个小时甚至几天。在这种情况下,仅在任务完成后保存最终模型是不够的。相反,建议在任务执行期间定期检查点进度,以便在发生故障时不会丢失数小时的工作。
您可以使用 Metaflow 扩展,metaflow-checkpoint,轻松创建和使用任务中的检查点:只需添加
@checkpoint 并调用 current.checkpoint.save 定期检查进度。@checkpoint 的一个主要好处是它会自动将检查点与 Metaflow 任务一起组织,因此您不需要手动处理保存、
加载、组织和跟踪检查点文件。
值得注意的是, @checkpoint 与流行的AI和ML框架如XGBoost、PyTorch等无缝集成,如下所述。有关更多背景,请阅读 关于@checkpoint的公告博客文章。
这个 @checkpoint 装饰器还不是 Metaflow 核心的内置部分,因此您需要按照下面的说明单独安装它。此外,它的 API 未来可能会改变,与核心 Metaflow 的 API 不同,后者保证向后兼容。请在 Metaflow Slack 上分享您的反馈!
安装 @checkpoint
要使用@checkpoint扩展,使用以下命令安装
pip install metaflow-checkpoint
在您开发和部署 Metaflow 代码的环境中。Metaflow 自动打包用于远程执行的扩展,因此您无需将其包含在用于远程运行任务的容器镜像中。
使用 @checkpoint
这个@checkpoint装饰器通过将文件保留在本地目录中到Metaflow数据存储中来操作。这使它与许多流行的支持本地保留检查点的ML和AI框架直接兼容。
让我们通过这个简单的流程演示功能,该流程试图在一个循环中增加计数器,但有20%的几率会失败。由于 @checkpoint 和 @retry,flaky_count 步骤能够从异常中恢复,并从最新的检查点继续计数,最终成功:
import os
import random
from metaflow import FlowSpec, current, step, retry, checkpoint
class CheckpointCounterFlow(FlowSpec):
@step
def start(self):
self.counter = 0
self.next(self.flaky_count)
@checkpoint
@retry
@step
def flaky_count(self):
cp_path = os.path.join(current.checkpoint.directory, "counter")
def _save_counter():
print(f"Checkpointing counter value {self.counter}")
with open(cp_path, "w") as f:
f.write(str(self.counter))
self.latest_checkpoint = current.checkpoint.save()
def _load_counter():
if current.checkpoint.is_loaded:
with open(cp_path) as f:
self.counter = int(f.read())
print(f"Checkpoint loaded!")
_load_counter()
print("Counter is now", self.counter)
while self.counter < 10:
self.counter += 1
if self.counter % 2 == 0:
_save_counter()
if random.random() < 0.2:
raise Exception("Bad luck! Try again!")
self.next(self.end)
@step
def end(self):
print("Final counter", self.counter)
if __name__ == "__main__":
CheckpointCounterFlow()
安装 metaflow-checkpoint 扩展后,您可以像往常一样运行流程:
python checkpoint_counter.py run
该流程演示了@checkpoint的典型用法:
@checkpoint初始化一个临时目录,current.checkpoint.directory,您可以将其用作待检查点数据的暂存区。默认情况下,
@checkpoint自动加载目录中最新的特定任务检查点。如果找到检查点,则current.checkpoint.is_loaded被设置为True,这样您就可以使用之前存储的数据初始化处理,例如在这种情况下加载counter的最新值。在处理过程中,您可以定期将任何需要的数据保存在暂存目录中,并调用
current.checkpoint.save()将其持久化到数据存储中。我们在一个工件中保存对最新检查点的引用,
latest_checkpoint,这使我们能够稍后找到并加载特定的检查点,如本文件后面所述。
在幕后,除了有效地加载和存储数据, @checkpoint 负责将检查点数据限定在特定任务中。您可以在许多并行任务中使用 @checkpoint,甚至在 foreach 中,知道 @checkpoint 将自动加载特定于每个分支的检查点。它还使得在运行之间使用检查点成为可能,如决定使用哪个检查点中所述。
通过卡片观察 @checkpoint
尝试使用默认的Metaflow
@card运行上述流程:
python checkpoint_counter.py run --with card
如果一个带有 @checkpoint 装饰的步骤启用了卡片,它将添加有关在卡片中加载和存储的检查点的信息。例如,下面的截图显示了与第二次尝试相关联的卡片(卡片顶部的 [Attempt: 1]),它加载了第一次尝试生成的检查点,并在每2秒的间隔存储了四个检查点:
