Semantic Textual Similarity

For Semantic Textual Similarity (STS), we want to produce embeddings for all texts involved and calculate the similarities between them. The text pairs with the highest similarity scores are the most semantically similar. See the Computing Embeddings documentation for more advanced details on computing embeddings.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

# Two lists of sentences
sentences1 = [
    "The new movie is awesome",
    "The cat sits outside",
    "A man is playing guitar",
]

sentences2 = [
    "The dog plays in the garden",
    "The new movie is so great",
    "A woman watches TV",
]

# Compute embeddings for both lists
embeddings1 = model.encode(sentences1)
embeddings2 = model.encode(sentences2)

# Compute cosine similarities
similarities = model.similarity(embeddings1, embeddings2)

# Output the pairs with their score
for idx_i, sentence1 in enumerate(sentences1):
    print(sentence1)
    for idx_j, sentence2 in enumerate(sentences2):
        print(f" - {sentence2: <30}: {similarities[idx_i][idx_j]:.4f}")

The new movie is awesome
 - The dog plays in the garden   : 0.0543
 - The new movie is so great     : 0.8939
 - A woman watches TV            : -0.0502
The cat sits outside
 - The dog plays in the garden   : 0.2838
 - The new movie is so great     : -0.0029
 - A woman watches TV            : 0.1310
A man is playing guitar
 - The dog plays in the garden   : 0.2277
 - The new movie is so great     : -0.0136
 - A woman watches TV            : -0.0327

In this example, the SentenceTransformer.similarity method returns a 3x3 matrix containing the cosine similarity scores for all possible pairs between embeddings1 and embeddings2.
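
Because the returned similarities is a torch tensor, ordinary tensor operations apply to it. As a minimal sketch continuing the script above, a row-wise argmax picks out the best match from sentences2 for each sentence in sentences1:

# Continuing the example above: select the closest sentence from sentences2
# for each sentence in sentences1 via a row-wise argmax
best_indices = similarities.argmax(dim=1)
for idx_i, sentence1 in enumerate(sentences1):
    idx_j = int(best_indices[idx_i])
    print(f"{sentence1} -> {sentences2[idx_j]} ({similarities[idx_i][idx_j]:.4f})")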

Similarity Calculation

The similarity metric that is used is stored under the SentenceTransformer.similarity_fn_name attribute of the SentenceTransformer instance. Valid options are:

  • SimilarityFunction.COSINE (a.k.a. "cosine"): Cosine Similarity (default)

  • SimilarityFunction.DOT_PRODUCT (a.k.a. "dot"): Dot Product

  • SimilarityFunction.EUCLIDEAN (a.k.a. "euclidean"): Negative Euclidean Distance

  • SimilarityFunction.MANHATTAN (a.k.a. "manhattan"): Negative Manhattan Distance
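
These metrics can be compared side by side by switching similarity_fn_name on the fly via its string aliases. A minimal sketch (the example sentences are arbitrary):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!"])

# Switch the metric via its string alias and score the same pair each time
for name in ("cosine", "dot", "euclidean", "manhattan"):
    model.similarity_fn_name = name
    similarities = model.similarity(embeddings, embeddings)
    print(f"{name:>10}: {similarities[0][1]:.4f}")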

This value can be changed in several ways:

  1. By initializing the SentenceTransformer instance with the desired similarity function:

    from sentence_transformers import SentenceTransformer, SimilarityFunction
    
    model = SentenceTransformer("all-MiniLM-L6-v2", similarity_fn_name=SimilarityFunction.DOT_PRODUCT)
    
  2. By setting the value directly on the SentenceTransformer instance:

    from sentence_transformers import SentenceTransformer, SimilarityFunction
    
    model = SentenceTransformer("all-MiniLM-L6-v2")
    model.similarity_fn_name = SimilarityFunction.DOT_PRODUCT
    
  3. By setting the value under the "similarity_fn_name" key in the config_sentence_transformers.json file of a saved model. When you save a Sentence Transformer model, this value is saved automatically as well; see the sketch after this list.
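
To illustrate the third option, the sketch below saves a model configured with dot product similarity and reads the key back from the generated config_sentence_transformers.json ("my-model" is a hypothetical local path):

import json

from sentence_transformers import SentenceTransformer, SimilarityFunction

model = SentenceTransformer("all-MiniLM-L6-v2", similarity_fn_name=SimilarityFunction.DOT_PRODUCT)
model.save("my-model")  # hypothetical local path

# The chosen similarity function is persisted in the saved configuration
with open("my-model/config_sentence_transformers.json") as f:
    print(json.load(f)["similarity_fn_name"])
# => "dot"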

Sentence Transformers implements two methods to calculate the similarity between embeddings:

  • SentenceTransformer.similarity: Calculates the similarity between all pairs of embeddings.

  • SentenceTransformer.similarity_pairwise: Calculates the similarity between embeddings in a pairwise fashion; see the sketch after the following example.

from sentence_transformers import SentenceTransformer, SimilarityFunction

# Load a pretrained Sentence Transformer model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Embed some sentences
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.6660, 0.1046],
#         [0.6660, 1.0000, 0.1411],
#         [0.1046, 0.1411, 1.0000]])

# Change the similarity function to Manhattan distance
model.similarity_fn_name = SimilarityFunction.MANHATTAN
print(model.similarity_fn_name)
# => "manhattan"

similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ -0.0000, -12.6269, -20.2167],
#         [-12.6269,  -0.0000, -20.1288],
#         [-20.2167, -20.1288,  -0.0000]])
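
The pairwise variant compares the two inputs elementwise instead of producing a full matrix. A minimal sketch (the example sentences are arbitrary):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

embeddings1 = model.encode(["The weather is lovely today.", "He drove to the stadium."])
embeddings2 = model.encode(["It's so sunny outside!", "She walked to the park."])

# Compares embeddings1[0] with embeddings2[0], and embeddings1[1] with
# embeddings2[1], returning one score per pair instead of a full matrix
print(model.similarity_pairwise(embeddings1, embeddings2))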

Note

If a Sentence Transformer instance ends with a Normalize module, then it is sensible to choose the "dot" metric instead of "cosine".

Dot product on normalized embeddings is equivalent to cosine similarity, but "cosine" re-normalizes the embeddings before comparing them. As a result, the "dot" metric will be faster than "cosine".
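
As a quick check of this equivalence, the following sketch normalizes the embeddings at encode time and compares both metrics:

import torch

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = ["The weather is lovely today.", "It's so sunny outside!"]

# L2-normalize the embeddings at encode time
embeddings = model.encode(sentences, normalize_embeddings=True)

model.similarity_fn_name = "cosine"
cosine_scores = model.similarity(embeddings, embeddings)

model.similarity_fn_name = "dot"
dot_scores = model.similarity(embeddings, embeddings)

# On normalized embeddings the two metrics agree up to float precision
print(torch.allclose(cosine_scores, dot_scores, atol=1e-6))  # True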

If you'd like to find the highest-scoring pairs in a long list of sentences, take a look at Paraphrase Mining.
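
For reference, a minimal sketch of that utility, using paraphrase_mining from sentence_transformers.util:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import paraphrase_mining

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = [
    "The new movie is awesome",
    "The new movie is so great",
    "The cat sits outside",
]

# Returns [score, i, j] triplets, sorted by decreasing similarity score
for score, i, j in paraphrase_mining(model, sentences):
    print(f"{score:.4f}: {sentences[i]} <-> {sentences[j]}")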