Source code for sentence_transformers.similarity_functions

from __future__ import annotations

from enum import Enum
from typing import Callable

from numpy import ndarray
from torch import Tensor

from .util import (
    cos_sim,
    dot_score,
    euclidean_sim,
    manhattan_sim,
    pairwise_cos_sim,
    pairwise_dot_score,
    pairwise_euclidean_sim,
    pairwise_manhattan_sim,
)


class SimilarityFunction(Enum):
    """
    Enum class for supported similarity functions. The following functions are supported:

    - ``SimilarityFunction.COSINE`` (``"cosine"``): Cosine similarity
    - ``SimilarityFunction.DOT_PRODUCT`` (``"dot"``, ``dot_product``): Dot product similarity
    - ``SimilarityFunction.EUCLIDEAN`` (``"euclidean"``): Euclidean distance
    - ``SimilarityFunction.MANHATTAN`` (``"manhattan"``): Manhattan distance
    """

    COSINE = "cosine"
    DOT_PRODUCT = "dot"
    DOT = "dot"  # Alias for DOT_PRODUCT
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

    @staticmethod
    def to_similarity_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function name or enum value to the corresponding similarity function.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The corresponding similarity function.

        Raises:
            ValueError: If the provided function is not supported.

        Example:
            >>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
            >>> similarity_scores = similarity_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([[0.3952, 0.0554],
                    [0.0992, 0.1570]])
        """
        similarity_function = SimilarityFunction(similarity_function)

        if similarity_function == SimilarityFunction.COSINE:
            return cos_sim
        if similarity_function == SimilarityFunction.DOT_PRODUCT:
            return dot_score
        if similarity_function == SimilarityFunction.MANHATTAN:
            return manhattan_sim
        if similarity_function == SimilarityFunction.EUCLIDEAN:
            return euclidean_sim

        raise ValueError(
            f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
        )

    @staticmethod
    def to_similarity_pairwise_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function into a pairwise similarity function.

        The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only
        computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than
        computing the similarity between all pairs of a and b.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The pairwise similarity function.

        Raises:
            ValueError: If the provided similarity function is not supported.

        Example:
            >>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
            >>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([0.3952, 0.1570])
        """
        similarity_function = SimilarityFunction(similarity_function)

        if similarity_function == SimilarityFunction.COSINE:
            return pairwise_cos_sim
        if similarity_function == SimilarityFunction.DOT_PRODUCT:
            return pairwise_dot_score
        if similarity_function == SimilarityFunction.MANHATTAN:
            return pairwise_manhattan_sim
        if similarity_function == SimilarityFunction.EUCLIDEAN:
            return pairwise_euclidean_sim

        raise ValueError(
            f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
        )

    @staticmethod
    def possible_values() -> list[str]:
        """
        Returns a list of possible values for the SimilarityFunction enum.

        Returns:
            list: A list of possible values for the SimilarityFunction enum.

        Example:
            >>> possible_values = SimilarityFunction.possible_values()
            >>> possible_values
            ['cosine', 'dot', 'euclidean', 'manhattan']
        """
        return [m.value for m in SimilarityFunction]
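

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the module above), showing how the enum
# resolves names or enum values to scoring callables. The embedding tensors
# below are made-up illustrative values, not outputs of any particular model.
import torch

embeddings1 = torch.randn(2, 4)
embeddings2 = torch.randn(3, 4)

# Full similarity matrix between all pairs: shape (2, 3).
similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
matrix_scores = similarity_fn(embeddings1, embeddings2)

# Pairwise (diagonal) similarities require equally sized inputs: shape (2,).
pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn(SimilarityFunction.COSINE)
pairwise_scores = pairwise_fn(embeddings1, embeddings2[:2])

# All accepted string values, e.g. for validating user-supplied configuration.
print(SimilarityFunction.possible_values())  # ['cosine', 'dot', 'euclidean', 'manhattan']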