Sentence Transformers 在 MLflow 中

注意

sentence_transformers 风格正在积极开发中,并被标记为实验性。公共API可能会发生变化,随着功能的增加,可能会添加新功能。

sentence_transformers 模型风格通过 mlflow.sentence_transformers.save_model()mlflow.sentence_transformers.log_model() 函数,实现了 sentence-transformers 模型 在 MLflow 格式中的记录。使用这些函数还会将 python_function 风格添加到 MLflow 模型中,使得模型可以通过 mlflow.pyfunc.load_model() 作为通用 Python 函数进行推理。此外,mlflow.sentence_transformers.load_model() 可以用于加载以 sentence_transformers 风格保存或记录的 MLflow 模型,并以原生的 sentence-transformers 格式加载。

Sentence Transformers 教程

想要直接进入一些可用的示例和教程,展示如何利用这个库与MLflow结合使用?

See the Tutorials

PyFunc 的输入和输出类型

sentence_transformers python_function (pyfunc) 模型风格 标准化了嵌入句子和计算语义相似性的过程。这种标准化允许通过将 sentence_transformers 所需的数据结构适配为与 JSON 序列化和转换为 Pandas DataFrames 兼容的格式,来进行服务和批量推理。

备注

sentence_transformers 风格支持多种模型用于嵌入生成、语义相似性和释义挖掘等任务。具体的输入和输出类型将取决于所使用的模型和任务。

保存和记录句子转换器模型

你可以在 MLflow 中保存和记录 sentence-transformers 模型。以下是保存和记录模型的示例:

import mlflow
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model_name")

# Saving the model
mlflow.sentence_transformers.save_model(model=model, path="path/to/save/directory")

# Logging the model
with mlflow.start_run():
    mlflow.sentence_transformers.log_model(
        sentence_transformers_model=model, artifact_path="model_artifact_path"
    )

使用与 OpenAI 兼容的推理接口保存 Sentence Transformers 模型

备注

此功能仅在 MLflow 2.11.0 及以上版本中可用。

MLflow 的 sentence_transformers 风格允许你在使用 mlflow.sentence_transformers.save_model()mlflow.sentence_transformers.log_model() 保存模型时,传入 task 参数,其字符串值为 "llm/v1/embeddings"

例如:

import mlflow
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

mlflow.sentence_transformers.save_model(
    model=model, path="path/to/save/directory", task="llm/v1/embeddings"
)

task 设置为 "llm/v1/embeddings" 时,MLflow 会为你处理以下内容:

  • 为模型设置一个兼容嵌入的签名

  • 执行数据预处理和后处理,以确保输入和输出符合 嵌入API规范 ,该规范与OpenAI的API规范兼容。

请注意,这些修改仅在模型通过 mlflow.pyfunc.load_model() 加载时适用(例如,当使用 mlflow models serve CLI 工具提供模型时)。如果你想只加载基础管道,你可以随时通过 mlflow.sentence_transformers.load_model() 来实现。

除了 sentence-transformers 风格外,transformers 风格还支持 OpenAI 兼容的推理接口("llm/v1/chat""llm/v1/completions")。更多信息请参阅 Transformers 风格指南

自定义 Python 函数实现

除了使用预构建的模型外,您还可以使用 sentence_transformers 风格创建自定义 Python 函数。以下是一个用于比较文本文档相似性的自定义实现的示例:

import mlflow
from mlflow.pyfunc import PythonModel
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util


class DocumentSimilarityModel(PythonModel):
    def load_context(self, context):
        """Load the model context for inference."""
        self.model = SentenceTransformer.load(context.artifacts["model_path"])

    def predict(self, context, model_input):
        """Predict method for comparing similarity between documents."""
        if isinstance(model_input, pd.DataFrame) and model_input.shape[1] == 2:
            documents = model_input.values
        else:
            raise ValueError("Input must be a DataFrame with exactly two columns.")

        # Compute embeddings for each document separately
        embeddings1 = self.model.encode(documents[:, 0], convert_to_tensor=True)
        embeddings2 = self.model.encode(documents[:, 1], convert_to_tensor=True)

        # Calculate cosine similarity
        similarity_scores = util.cos_sim(embeddings1, embeddings2)

        return pd.DataFrame(similarity_scores.numpy(), columns=["similarity_score"])


# Example model saving and loading
model = SentenceTransformer("all-MiniLM-L6-v2")
model_path = "/tmp/sentence_transformers_model"
model.save(model_path)

# Example usage
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="document_similarity_model",
        python_model=DocumentSimilarityModel(),
        artifacts={"model_path": model_path},
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)

# Test prediction
df = pd.DataFrame(
    {
        "doc1": ["Sentence Transformers is a wonderful package!"],
        "doc2": ["MLflow is pretty great too!"],
    }
)

result = loaded.predict(df)
print(result)

这将生成传递的文档的相似性分数,如下所示:

   similarity_score
0          0.275423