检查训练结果#

你的 Trainer.fit() 调用的返回值是一个 Result 对象。

Result 对象包含,除其他信息外:

  • 最后报告的指标(例如损失)

  • 最后一个报告的检查点(用于加载模型)

  • 错误信息,如果发生任何错误

查看指标#

你可以从 Result 对象中检索报告给 Ray Train 的指标。

常见的指标包括训练或验证损失,或预测准确率。

Result 对象中检索的指标对应于您作为参数传递给 train.report 的那些指标 在您的训练函数中

最后报告的指标#

使用 Result.metrics 来检索最新报告的指标。

result = trainer.fit()

print("Observed metrics:", result.metrics)

所有报告指标的数据框#

使用 Result.metrics_dataframe 来检索所有报告指标的 pandas DataFrame。

df = result.metrics_dataframe
print("Minimum loss", min(df["loss"]))

检索检查点#

你可以从 Result 对象中检索报告给 Ray Train 的检查点。

检查点 包含了恢复训练状态所需的所有信息。这通常包括训练好的模型。

你可以使用检查点进行常见的下游任务,例如 使用 Ray Data 进行离线批量推理使用 Ray Serve 进行在线模型服务

Result 对象中检索到的检查点对应于您作为参数传递给 train.report 的那些 在您的训练函数中

最后保存的检查点#

使用 Result.checkpoint 来检索最后一个检查点。

print("Last checkpoint:", result.checkpoint)

with result.checkpoint.as_directory() as tmpdir:
    # Load model from directory
    ...

其他检查点#

有时你想访问一个更早的检查点。例如,如果你的损失在更多训练后由于过拟合而增加,你可能想要检索损失最小的检查点。

你可以通过 Result.best_checkpoints 获取所有可用检查点及其指标的列表。

# Print available checkpoints
for checkpoint, metrics in result.best_checkpoints:
    print("Loss", metrics["loss"], "checkpoint", checkpoint)

# Get checkpoint with minimal loss
best_checkpoint = min(
    result.best_checkpoints, key=lambda checkpoint: checkpoint[1]["loss"]
)[0]

with best_checkpoint.as_directory() as tmpdir:
    # Load model from directory
    ...

参见

有关检查点的更多信息,请参阅 保存和加载检查点

访问存储位置#

如果你需要稍后检索结果,你可以使用 Result.path 获取训练运行的存储位置。

此路径将对应于您在 RunConfig 中配置的 storage_path。它将是该路径内的一个(嵌套的)子目录,通常形式为 TrainerName_date-string/TrainerName_id_00000_0_...

结果还包含一个 pyarrow.fs.FileSystem,可用于访问存储位置,这在路径位于云存储时非常有用。

result_path: str = result.path
result_filesystem: pyarrow.fs.FileSystem = result.filesystem

print(f"Results location (fs, path) = ({result_filesystem}, {result_path})")

您可以使用 Result.from_path 恢复结果:

from ray.train import Result

restored_result = Result.from_path(result_path)
print("Restored loss", result.metrics["loss"])

查看错误#

如果在训练过程中发生错误,Result.error 将被设置并包含引发的异常。

if result.error:
    assert isinstance(result.error, Exception)

    print("Got exception:", result.error)

在持久存储中查找结果#

所有训练结果,包括报告的指标、检查点和错误文件,都存储在配置的 持久存储 中。

请参阅 我们的持久存储指南 以配置您的训练运行的此位置。