from __future__ import annotations
from enum import Enum
from typing import Callable
from numpy import ndarray
from torch import Tensor
from .util import (
cos_sim,
dot_score,
euclidean_sim,
manhattan_sim,
pairwise_cos_sim,
pairwise_dot_score,
pairwise_euclidean_sim,
pairwise_manhattan_sim,
)
# [docs]
class SimilarityFunction(Enum):
    """
    Enum class for supported similarity functions. The following functions are supported:

    - ``SimilarityFunction.COSINE`` (``"cosine"``): Cosine similarity
    - ``SimilarityFunction.DOT_PRODUCT`` (``"dot"``): Dot product similarity.
      ``SimilarityFunction.DOT`` is an alias for the same member.
    - ``SimilarityFunction.EUCLIDEAN`` (``"euclidean"``): Euclidean distance
    - ``SimilarityFunction.MANHATTAN`` (``"manhattan"``): Manhattan distance
    """

    COSINE = "cosine"
    DOT_PRODUCT = "dot"
    DOT = "dot"  # Alias for DOT_PRODUCT (same value, so it resolves to the same member)
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

    @staticmethod
    def _resolve(similarity_function: str | SimilarityFunction) -> SimilarityFunction:
        """
        Normalizes a name or enum value to a ``SimilarityFunction`` member.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            SimilarityFunction: The matching enum member.

        Raises:
            ValueError: If the provided value does not match any member; the message lists the supported values.
        """
        try:
            return SimilarityFunction(similarity_function)
        except ValueError as err:
            # The bare Enum lookup error only says "'x' is not a valid SimilarityFunction";
            # re-raise with the list of accepted values so callers can fix their input.
            raise ValueError(
                f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
            ) from err

    @staticmethod
    def to_similarity_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function name or enum value to the corresponding similarity function.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The corresponding similarity function.

        Raises:
            ValueError: If the provided function is not supported.

        Example:
            >>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
            >>> similarity_scores = similarity_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([[0.3952, 0.0554],
                    [0.0992, 0.1570]])
        """
        # Validate first so an unsupported name fails with the descriptive error.
        member = SimilarityFunction._resolve(similarity_function)
        # Every member is covered; DOT is an alias of DOT_PRODUCT and shares its entry.
        dispatch = {
            SimilarityFunction.COSINE: cos_sim,
            SimilarityFunction.DOT_PRODUCT: dot_score,
            SimilarityFunction.MANHATTAN: manhattan_sim,
            SimilarityFunction.EUCLIDEAN: euclidean_sim,
        }
        return dispatch[member]

    @staticmethod
    def to_similarity_pairwise_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function into a pairwise similarity function.

        The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only
        computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than
        computing the similarity between all pairs of a and b.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The pairwise similarity function.

        Raises:
            ValueError: If the provided similarity function is not supported.

        Example:
            >>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
            >>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([0.3952, 0.1570])
        """
        # Validate first so an unsupported name fails with the descriptive error.
        member = SimilarityFunction._resolve(similarity_function)
        # Every member is covered; DOT is an alias of DOT_PRODUCT and shares its entry.
        dispatch = {
            SimilarityFunction.COSINE: pairwise_cos_sim,
            SimilarityFunction.DOT_PRODUCT: pairwise_dot_score,
            SimilarityFunction.MANHATTAN: pairwise_manhattan_sim,
            SimilarityFunction.EUCLIDEAN: pairwise_euclidean_sim,
        }
        return dispatch[member]

    @staticmethod
    def possible_values() -> list[str]:
        """
        Returns a list of possible values for the SimilarityFunction enum.

        Iterating an Enum skips aliases, so ``"dot"`` appears only once.

        Returns:
            list: A list of possible values for the SimilarityFunction enum.

        Example:
            >>> possible_values = SimilarityFunction.possible_values()
            >>> possible_values
            ['cosine', 'dot', 'euclidean', 'manhattan']
        """
        return [member.value for member in SimilarityFunction]