9. 模型持久化#
在训练完一个 scikit-learn 模型后,希望有一种方法可以将模型持久化,以便将来使用而无需重新训练。根据您的使用场景,有几种不同的方式可以持久化 scikit-learn 模型,这里我们将帮助您决定哪种方式最适合您。为了做出决定,您需要回答以下问题:
您是否需要在持久化后保留 Python 对象,还是只需要持久化模型以便提供服务并从中获取预测结果?
如果您只需要提供模型服务,而不需要在 Python 对象本身上进行进一步的调查,那么 ONNX 可能是您的最佳选择。请注意,并非所有模型都受 ONNX 支持。
如果 ONNX 不适合您的使用场景,下一个问题是:
您是否完全信任模型的来源,或者对持久化模型的来源有任何安全顾虑?
如果您有安全顾虑,那么您应该考虑使用 skops.io ,它会将 Python 对象返回给您,但与基于 pickle
的持久化解决方案不同,加载持久化模型不会自动允许任意代码执行。请注意,这需要手动调查持久化文件,而 skops.io
允许您这样做。
其他解决方案假设您完全信任要加载的文件的来源,因为它们在加载持久化文件时都容易受到任意代码执行的影响,因为它们都在底层使用了 pickle 协议。
您是否关心加载模型的性能,以及在进程间共享模型时,磁盘上的内存映射对象是否有益?
如果是,那么您可以考虑使用 joblib 。如果这不是您的主要关注点,那么您可以使用内置的 pickle
模块。
4. 你是否尝试过使用 pickle
或 joblib
却发现模型无法持久化?例如,当你的模型中包含用户定义的函数时,这种情况就可能发生。
如果是这样,你可以使用 cloudpickle ,它能够序列化某些 pickle
或 joblib
无法序列化的对象。
9.1. 工作流程概览#
在一个典型的工作流程中,第一步是使用 scikit-learn 和兼容的库来训练模型。请注意,不同持久化方法对 scikit-learn 和第三方估计器的支持程度各不相同。
9.1.1. 训练并持久化模型#
创建适当的模型取决于你的使用场景。作为一个例子,这里我们在鸢尾花数据集上训练一个 sklearn.ensemble.HistGradientBoostingClassifier
>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()
一旦模型训练完成,你可以使用你选择的方法将其持久化,然后在另一个环境中加载模型并根据输入数据获取预测结果。这里主要有两条路径,取决于你如何持久化并计划提供模型:
ONNX : 你需要一个
ONNX
运行时和一个安装了适当依赖项的环境来加载模型并使用运行时获取预测。这个环境可以非常精简,甚至不一定需要安装 Python 来加载模型和计算预测。另外请注意,onnxruntime
通常比 Python 需要更少的内存来从小型模型中计算预测。skops.io
,pickle
,joblib
, cloudpickle : 你需要一个安装了适当依赖项的 Python 环境来加载模型并从中获取预测。这个环境应该与训练模型的环境相同,
包**和与模型训练环境相同的**版本。请注意,这些方法都不支持加载使用不同版本的scikit-learn训练的模型,以及其他可能不同版本的依赖项,如 numpy
和 scipy
。另一个考虑因素是在不同的硬件上运行持久化的模型,在大多数情况下,您应该能够在不同的硬件上加载持久化的模型。
9.2. ONNX#
ONNX
,或Open Neural Network Exchange <https://onnx.ai/>`__格式最适合在需要持久化模型然后使用持久化文件进行预测而不需要加载Python对象的情况下使用。在服务环境需要精简和最小化的情况下也很有用,因为 `ONNX
运行时不要求python
。
ONNX
是模型的二进制序列化。它是为了提高数据模型互操作表示的可用性而开发的。它的目标是促进数据模型在不同机器学习框架之间的转换,并提高它们在不同计算架构上的可移植性。更多详情可从ONNX教程 <https://onnx.ai/get-started.html>`__获取。要将scikit-learn模型转换为 `ONNX
,已开发了 sklearn-onnx 。然而,并非所有scikit-learn模型都受支持,并且仅限于核心scikit-learn,不支持大多数第三方估计器。可以为第三方或自定义估计器编写自定义转换器,但相关文档稀少,可能具有挑战性。
#使用ONNX
要将模型转换为 ONNX
格式,您需要向转换器提供一些关于输入的信息,更多详情可以阅读 这里:
from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
f.write(onx.SerializeToString())
你可以在 Python 中加载模型,并使用 ONNX
运行时获取预测:
from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
skops.io
skops.io
避免了使用 pickle
,并且只加载文件中包含的类型和函数引用,这些类型和引用要么默认受信任,要么由用户信任。因此,它比 pickle
、joblib
和 cloudpickle 提供了更安全的格式。
#使用 skops
API 与 pickle
非常相似,你可以按照 文档 中解释的那样,使用 skops.io.dump
和 skops.io.dumps
持久化你的模型:
import skops.io as sio
obj = sio.dump(clf, "filename.skops")
你可以使用 skops.io.load
和 skops.io.loads
加载它们回来。然而,你需要指定你信任的类型。你可以使用 skops.io.get_untrusted_types
获取已转储对象/文件中的现有未知类型,并在检查其内容后,将其传递给加载函数:
unknown_types = sio.get_untrusted_types(file="filename.skops")
# 调查 unknown_types 的内容,只有在信任你所看到的一切时才加载。
clf = sio.load("filename.skops", trusted=unknown_types)
请在 skops 问题跟踪器 上报告与此格式相关的问题和功能请求。
pickle
,joblib
, 和cloudpickle
pickle
是 Python 标准库中的一个模块。它可以序列化和反序列化任何 Python 对象,包括自定义的 Python 类和对象。joblib
在处理大型机器学习模型或大型 numpy 数组时比pickle
更高效。cloudpickle 可以序列化某些不能被
pickle
或joblib
序列化的对象,例如用户定义的函数和 lambda 函数。这种情况可能发生在使用FunctionTransformer
并使用自定义函数转换数据时。
#使用 pickle
, joblib
, 或 cloudpickle
根据您的使用场景,您可以选择这三种方法之一来持久化和加载您的 scikit-learn 模型,它们都遵循相同的 API:
# 这里可以用 joblib 或 cloudpickle 替换 pickle
from pickle import dump
with open("filename.pkl", "wb") as f:
dump(clf, f, protocol=5)
推荐使用 protocol=5
以减少内存使用并加快存储和加载任何作为模型拟合属性存储的大型 NumPy 数组。您也可以传递 protocol=pickle.HIGHEST_PROTOCOL
,这在 Python 3.8 及更高版本中(写作时)等同于 protocol=5
。
之后在需要时,您可以从持久化文件中加载相同的对象:
# 这里可以用 joblib 或 cloudpickle 替换 pickle
from pickle import load
with open("filename.pkl", "rb") as f:
clf = load(f)
9.3. 安全性和可维护性限制#
pickle
(以及扩展的 joblib
和 cloudpickle
)存在许多设计上的安全漏洞,仅当制品(即 pickle 文件)来自可信和验证的来源时才应使用。
来源。您永远不应该从不信任的来源加载 pickle 文件,同样地,您也不应该从不信任的来源执行代码。
另请注意,可以使用 ONNX
格式表示任意计算,因此建议在沙盒环境中使用 ONNX
提供模型,以防止计算和内存漏洞。
还要注意,没有支持的方法来加载使用不同版本的 scikit-learn 训练的模型。虽然使用 skops.io
、joblib
、pickle
或 cloudpickle ,使用一个版本的 scikit-learn 保存的模型可能在其他版本中加载,但这完全不受支持且不建议这样做。还应记住,对这些数据执行的操作可能会给出不同的和意外的结果,甚至可能导致您的 Python 进程崩溃。
为了使用未来版本的 scikit-learn 重建类似的模型,应该与 pickle 模型一起保存额外的元数据:
训练数据,例如对不可变快照的引用
用于生成模型的 Python 源代码
scikit-learn 及其依赖项的版本
在训练数据上获得的交叉验证分数
这应该使得检查交叉验证分数是否在同一范围内成为可能。
除了少数例外,假设使用相同版本的依赖项和 Python,持久化的模型应该可以在操作系统和硬件架构之间移植。如果您遇到不可移植的估计器,请在 GitHub 上提出问题。持久化的模型通常使用 Docker 等容器在生产环境中部署,以冻结环境和依赖项。
如果您想了解更多关于这些问题的信息,请参考以下演讲:
Adrin Jalali: Let’s exploit pickle, and skops to the rescue! | PyData Amsterdam 2023 。
`Alex Gaynor: Pickles are for Delis, not Software - PyCon 2014
<https://pyvideo.org/video/2566/pickles-are-for-delis-not-software>`__.
9.3.1. 在生产环境中复制训练环境#
如果依赖项的版本在训练和生产之间可能不同,那么在使用训练好的模型时可能会导致意外行为和错误。为了防止这种情况,建议在训练和生产环境中使用相同的依赖项和版本。这些传递依赖项可以通过 pip
、 mamba
、 conda
、 poetry
、 conda-lock
、 pixi
等包管理工具固定。
并非总是可以在更新后的软件环境中加载使用较旧版本的 scikit-learn 库及其依赖项训练的模型。相反,您可能需要使用所有库的新版本重新训练模型。因此,在训练模型时,记录训练配方(例如 Python 脚本)和训练集信息,以及所有依赖项的元数据非常重要,以便能够自动重建相同的训练环境以进行更新的软件。
#InconsistentVersionWarning
当使用与估计器序列化时版本不一致的 scikit-learn 版本加载估计器时,会引发 InconsistentVersionWarning
。可以捕获此警告以获取估计器序列化时的原始版本:
from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)
try:
with open("model_from_prevision_version.pickle", "rb") as f:
est = pickle.load(f)
except InconsistentVersionWarning as w:
print(w.original_sklearn_version)
9.3.2. 提供模型工件#
训练 scikit-learn 模型的最后一步是提供模型。
一旦训练好的模型成功加载,它就可以用于管理不同的预测请求。这可能涉及将模型作为Web服务使用容器化部署,或其他根据规范的模型部署策略。
9.4. 总结关键点#
根据不同的模型持久化方法,每个方法的关键点可以总结如下:
ONNX
:它为任何机器学习或深度学习模型(除了scikit-learn)提供了一个统一的格式进行持久化,并且对于模型推理(预测)很有用。然而,它可能导致与不同框架的兼容性问题。skops.io
:训练好的scikit-learn模型可以轻松地通过:mod:skops.io
共享和投入生产。与基于:mod:pickle
的替代方法相比,它更安全,因为它不会加载任意代码,除非用户明确要求。此类代码需要在目标Python环境中打包和可导入。joblib
:高效的内存映射技术使得在使用mmap_mode="r"
时,在多个Python进程中使用相同的持久化模型更快。它还提供了压缩和解压缩持久化对象的便捷快捷方式,无需额外代码。然而,从不可信来源加载模型时,它可能触发恶意代码的执行,就像其他基于pickle的持久化机制一样。pickle
:它是Python原生的,大多数Python对象可以使用:mod:pickle
进行序列化和反序列化,包括自定义Python类和函数,只要它们在目标环境中可以导入的包中定义。虽然:mod:pickle
可以用于轻松保存和加载scikit-learn模型,但从不可信来源加载模型时,它可能触发恶意代码的执行。pickle
也可以用于序列化其他Python对象。 如果模型是以protocol=5
持久化的,那么在内存使用上会非常高效,但它不支持内存映射。cloudpickle : 它的加载效率与
pickle
和joblib
(不带内存映射)相当,但提供了额外的灵活性来序列化自定义的 Python 代码,如 lambda 表达式和交互式定义的函数和类。在持久化包含自定义 Python 组件的管道时,它可能是最后的手段,例如包装了在训练脚本本身或更一般地在任何可导入的 Python 包之外定义的函数的sklearn.preprocessing.FunctionTransformer
。请注意, cloudpickle 不提供向前兼容性保证,您可能需要使用相同版本的 cloudpickle 以及定义模型时使用的所有库的相同版本来加载持久化的模型。与其他基于 pickle 的持久化机制一样,从不受信任的来源加载模型时可能会触发恶意代码的执行。