"""基于RapidFuzz库的字符串距离评估器。"""
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
from langchain.schema import RUN_KEY
def _load_rapidfuzz() -> Any:
"""加载RapidFuzz库。
引发:
ImportError: 如果未安装rapidfuzz库。
返回:
任意: rapidfuzz.distance模块。
"""
try:
import rapidfuzz
except ImportError:
raise ImportError(
"Please install the rapidfuzz library to use the FuzzyMatchStringEvaluator."
"Please install it with `pip install rapidfuzz`."
)
return rapidfuzz.distance
[docs]class StringDistance(str, Enum):
"""距离度量标准。
属性:
DAMERAU_LEVENSHTEIN:Damerau-Levenshtein距离。
LEVENSHTEIN:Levenshtein距离。
JARO:Jaro距离。
JARO_WINKLER:Jaro-Winkler距离。
HAMMING:Hamming距离。
INDEL:Indel距离。"""
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
LEVENSHTEIN = "levenshtein"
JARO = "jaro"
JARO_WINKLER = "jaro_winkler"
HAMMING = "hamming"
INDEL = "indel"
class _RapidFuzzChainMixin(Chain):
"""用于快速fuzz字符串距离评估器的共享方法。"""
distance: StringDistance = Field(default=StringDistance.JARO_WINKLER)
normalize_score: bool = Field(default=True)
"""是否将分数归一化为0到1之间的值。
仅适用于Levenshtein和Damerau-Levenshtein距离。"""
@root_validator
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""验证 rapidfuzz 库是否已安装。
参数:
values(Dict[str, Any]):输入数值。
返回:
Dict[str, Any]:经过验证的数值。
"""
_load_rapidfuzz()
return values
@property
def output_keys(self) -> List[str]:
"""获取输出键。
返回:
List[str]:输出键。
"""
return ["score"]
def _prepare_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""准备输出字典。
参数:
result (Dict[str, Any]): 评估结果。
返回:
Dict[str, Any]: 准备好的输出字典。
"""
result = {"score": result["score"]}
if RUN_KEY in result:
result[RUN_KEY] = result[RUN_KEY].dict()
return result
@staticmethod
def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
"""根据距离类型获取距离度量函数。
参数:
distance (str): 距离类型。
返回:
Callable: 距离度量函数。
引发:
ValueError: 如果距离度量无效。
"""
from rapidfuzz import distance as rf_distance
module_map: Dict[str, Any] = {
StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
StringDistance.JARO: rf_distance.Jaro,
StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
StringDistance.HAMMING: rf_distance.Hamming,
StringDistance.INDEL: rf_distance.Indel,
}
if distance not in module_map:
raise ValueError(
f"Invalid distance metric: {distance}"
f"\nMust be one of: {list(StringDistance)}"
)
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else:
return module.distance
@property
def metric(self) -> Callable:
"""获取距离度量函数。
返回:
Callable:距离度量函数。
"""
return _RapidFuzzChainMixin._get_metric(
self.distance, normalize_score=self.normalize_score
)
def compute_metric(self, a: str, b: str) -> float:
"""计算两个字符串之间的距离。
参数:
a (str): 第一个字符串。
b (str): 第二个字符串。
返回:
float: 两个字符串之间的距离。
"""
return self.metric(a, b)
[docs]class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""计算预测值和参考值之间的字符串距离。
示例
----------
>>> from langchain.evaluation import StringDistanceEvalChain
>>> evaluator = StringDistanceEvalChain()
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="Mindy is the CEO",
)
使用`load_evaluator`函数:
>>> from langchain.evaluation import load_evaluator
>>> evaluator = load_evaluator("string_distance")
>>> evaluator.evaluate_strings(
prediction="The answer is three",
reference="three",
)
"""
@property
def requires_input(self) -> bool:
"""
这个评估器不需要输入。
"""
return False
@property
def requires_reference(self) -> bool:
"""
这个评估器不需要参考。
"""
return True
@property
def input_keys(self) -> List[str]:
"""获取输入键。
返回:
List[str]: 输入键。
"""
return ["reference", "prediction"]
@property
def evaluation_name(self) -> str:
"""获取评估名称。
返回:
str:评估名称。
"""
return f"{self.distance.value}_distance"
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""计算预测值和参考值之间的字符串距离。
参数:
inputs(Dict[str, Any]):输入数值。
run_manager(Optional[CallbackManagerForChainRun]):
回调管理器。
返回:
Dict[str, Any]:包含分数的评估结果。
"""
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""异步计算预测值和参考值之间的字符串距离。
参数:
inputs(Dict[str, Any]):输入数值。
run_manager(Optional[AsyncCallbackManagerForChainRun]):
回调管理器。
返回:
Dict[str, Any]:包含分数的评估结果。
"""
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
def _evaluate_strings(
self,
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""评估预测值和参考值之间的字符串距离。
参数:
prediction (str): 预测字符串。
reference (Optional[str], optional): 参考字符串。
input (Optional[str], optional): 输入字符串。
callbacks (Callbacks, optional): 要使用的回调函数。
**kwargs: 附加的关键字参数。
返回:
dict: 包含得分的评估结果。
"""
result = self(
inputs={"prediction": prediction, "reference": reference},
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)
async def _aevaluate_strings(
self,
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""异步评估预测值和参考值之间的字符串距离。
参数:
prediction (str): 预测字符串。
reference (Optional[str], optional): 参考字符串。
input (Optional[str], optional): 输入字符串。
callbacks (Callbacks, optional): 要使用的回调函数。
**kwargs: 附加的关键字参数。
返回:
dict: 包含分数的评估结果。
"""
result = await self.acall(
inputs={"prediction": prediction, "reference": reference},
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)
[docs]class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMixin):
"""计算两个预测之间的字符串编辑距离。"""
@property
def input_keys(self) -> List[str]:
"""获取输入键。
返回:
List[str]: 输入键。
"""
return ["prediction", "prediction_b"]
@property
def evaluation_name(self) -> str:
"""获取评估名称。
返回:
str:评估名称。
"""
return f"pairwise_{self.distance.value}_distance"
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""计算两个预测之间的字符串距离。
参数:
inputs (Dict[str, Any]): 输入数值。
run_manager (CallbackManagerForChainRun , optional):
回调管理器。
返回:
Dict[str, Any]: 包含得分的评估结果。
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"])
}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""异步计算两个预测之间的字符串距离。
参数:
inputs (Dict[str, Any]): 输入数值。
run_manager (AsyncCallbackManagerForChainRun , optional):
回调管理器。
返回:
Dict[str, Any]: 包含分数的评估结果。
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"])
}
def _evaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""评估两个预测之间的字符串距离。
参数:
prediction (str): 第一个预测字符串。
prediction_b (str): 第二个预测字符串。
callbacks (Callbacks, optional): 要使用的回调函数。
tags (List[str], optional): 应用于跟踪的标签。
metadata (Dict[str, Any], optional): 应用于跟踪的元数据。
**kwargs: 额外的关键字参数。
返回:
dict: 包含分数的评估结果。
"""
result = self(
inputs={"prediction": prediction, "prediction_b": prediction_b},
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)
async def _aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""异步地评估两个预测之间的字符串距离。
参数:
prediction (str): 第一个预测字符串。
prediction_b (str): 第二个预测字符串。
callbacks (Callbacks, optional): 要使用的回调。
tags (List[str], optional): 应用于跟踪的标签。
metadata (Dict[str, Any], optional): 应用于跟踪的元数据。
**kwargs: 附加的关键字参数。
返回:
dict: 包含分数的评估结果。
"""
result = await self.acall(
inputs={"prediction": prediction, "prediction_b": prediction_b},
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)