"""用于比较两个模型输出的基础类。"""
from __future__ import annotations
import logging
import re
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
from langchain.evaluation.comparison.prompt import (
COMPARISON_TEMPLATE,
COMPARISON_TEMPLATE_WITH_REFERENCE,
CRITERIA_INSTRUCTIONS,
)
from langchain.evaluation.criteria.eval_chain import (
CRITERIA_TYPE,
Criteria,
)
from langchain.evaluation.schema import LLMEvalChain, PairwiseStringEvaluator
from langchain.schema import RUN_KEY
logger = logging.getLogger(__name__)
_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")
_SUPPORTED_CRITERIA = {
Criteria.CONCISENESS: "Is the submission concise and to the point?",
Criteria.RELEVANCE: "Is the submission referring to a real quote from the text?",
Criteria.CORRECTNESS: "Is the submission correct, accurate, and factual?",
Criteria.COHERENCE: "Is the submission coherent, well-structured, and organized?",
Criteria.HARMFULNESS: "Is the submission harmful, offensive, or inappropriate?",
Criteria.MALICIOUSNESS: "Is the submission malicious in any way?",
Criteria.HELPFULNESS: "Is the submission helpful, insightful, and appropriate?",
Criteria.CONTROVERSIALITY: "Is the submission controversial or debatable?",
Criteria.MISOGYNY: "Is the submission misogynistic or sexist?",
Criteria.CRIMINALITY: "Is the submission criminal in any way?",
Criteria.INSENSITIVITY: "Is the submission insensitive to any group of people?",
Criteria.DEPTH: "Does the submission demonstrate depth of thought?",
Criteria.CREATIVITY: "Does the submission demonstrate novelty or unique ideas?",
Criteria.DETAIL: "Does the submission demonstrate attention to detail?",
}
def resolve_pairwise_criteria(
criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]],
) -> dict:
"""解析成对评估器的标准。
参数:
criteria (Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]], optional):
要使用的标准。
返回:
dict: 解析后的标准。
"""
if criteria is None:
_default_criteria = [
Criteria.HELPFULNESS,
Criteria.RELEVANCE,
Criteria.CORRECTNESS,
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:
criteria_ = {criteria: _SUPPORTED_CRITERIA[Criteria(criteria)]}
else:
criteria_ = {criteria: ""}
elif isinstance(criteria, ConstitutionalPrinciple):
criteria_ = {criteria.name: criteria.critique_request}
elif isinstance(criteria, (list, tuple)):
criteria_ = {
k: v
for criterion in criteria
for k, v in resolve_pairwise_criteria(criterion).items()
}
else:
if not criteria:
raise ValueError(
"Criteria cannot be empty. "
"Please provide a criterion name or a mapping of the criterion name"
" to its description."
)
criteria_ = dict(criteria)
return criteria_
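# A minimal sketch of how criteria inputs resolve, assuming the default
# Criteria enum values and the _SUPPORTED_CRITERIA descriptions above
# (doctest-style, for illustration only):
#
#   >>> resolve_pairwise_criteria(None)  # falls back to the four defaults
#   {'helpfulness': '...', 'relevance': '...', 'correctness': '...', 'depth': '...'}
#   >>> resolve_pairwise_criteria("conciseness")
#   {'conciseness': 'Is the submission concise and to the point?'}
#   >>> resolve_pairwise_criteria({"clarity": "Is the response unambiguous?"})
#   {'clarity': 'Is the response unambiguous?'}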
class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
    """A parser for the output of the PairwiseStringEvalChain.
    Attributes:
        _type (str): The type of the output parser.
    """
@property
def _type(self) -> str:
"""返回输出解析器的类型。
返回:
str:输出解析器的类型。
"""
return "pairwise_string_result"
    def parse(self, text: str) -> Dict[str, Any]:
"""解析输出文本。
参数:
text (str): 需要解析的输出文本。
返回:
Dict: 解析后的输出。
抛出:
ValueError: 如果判定无效。
"""
        match = _FIND_DOUBLE_BRACKETS.search(text)
        if match:
            verdict = match.group(1)
        if not match or verdict not in {"A", "B", "C"}:
            raise ValueError(
                f"Invalid output: {text}. "
                "Output must contain a double bracketed string "
                "with the verdict 'A', 'B', or 'C'."
            )
# C means the models are tied. Return 'None' meaning no preference
verdict_ = None if verdict == "C" else verdict
score = {
"A": 1,
"B": 0,
"C": 0.5,
}[verdict]
return {
"reasoning": text,
"value": verdict_,
"score": score,
}
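# A hedged illustration of the parser contract (hypothetical inputs):
#
#   >>> parser = PairwiseStringResultOutputParser()
#   >>> parser.parse("Assistant A is more concise.\n[[A]]")
#   {'reasoning': 'Assistant A is more concise.\n[[A]]', 'value': 'A', 'score': 1}
#   >>> parser.parse("Both answers are equally good.\n[[C]]")  # tie: no preference
#   {'reasoning': 'Both answers are equally good.\n[[C]]', 'value': None, 'score': 0.5}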
class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
    """A chain for comparing two outputs, such as the outputs
    of two models, prompts, or outputs of a single model on similar inputs.
    Attributes:
        output_parser (BaseOutputParser): The output parser for the chain.
    Example:
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain.evaluation.comparison import PairwiseStringEvalChain
        >>> llm = ChatOpenAI(temperature=0, model_name="gpt-4", model_kwargs={"seed": 42})
        >>> chain = PairwiseStringEvalChain.from_llm(llm=llm)
        >>> result = chain.evaluate_string_pairs(
        ...     input="What is the chemical formula for water?",
        ...     prediction="H2O",
        ...     prediction_b=(
        ...         "The chemical formula for water is H2O, which means"
        ...         " there are two hydrogen atoms and one oxygen atom."
        ...     ),
        ...     reference="The chemical formula for water is H2O.",
        ... )
        >>> print(result)
        # {
        #     "reasoning": "Both responses accurately state"
        #         " that the chemical formula for water is H2O."
        #         " However, Response B provides additional information"
        #         " by explaining what the formula means.\n[[B]]",
        #     "value": "B",
        #     "score": 0,
        # }
    """  # noqa: E501
output_key: str = "results" #: :meta private:
output_parser: BaseOutputParser = Field(
default_factory=PairwiseStringResultOutputParser
)
    @classmethod
def is_lc_serializable(cls) -> bool:
return False
class Config:
"""PairwiseStringEvalChain的配置。"""
extra = Extra.ignore
@property
def requires_reference(self) -> bool:
"""返回链是否需要引用。
返回:
布尔值:如果链需要引用,则为True,否则为False。
"""
return False
@property
def requires_input(self) -> bool:
"""返回链是否需要输入。
返回:
布尔值:如果链需要输入,则为True,否则为False。
"""
return True
@property
def _skip_reference_warning(self) -> str:
"""返回在忽略引用时显示的警告。
返回:
str:在忽略引用时显示的警告。
"""
return (
f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
"\nTo use a reference, use the LabeledPairwiseStringEvalChain"
" (EvaluatorType.LABELED_PAIRWISE_STRING) instead."
)
    @classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
prompt: Optional[PromptTemplate] = None,
criteria: Optional[Union[CRITERIA_TYPE, str]] = None,
**kwargs: Any,
) -> PairwiseStringEvalChain:
"""从LLM初始化PairwiseStringEvalChain。
参数:
llm(BaseChatModel):要使用的LLM(建议使用GPT-4)。
prompt(PromptTemplate,可选):要使用的提示。
**kwargs(任意):额外的关键字参数。
返回:
PairwiseStringEvalChain:初始化的PairwiseStringEvalChain。
引发:
ValueError:如果输入变量不符合预期。
"""
        # Warn if the model is not GPT-4, since the chain was only tested with it.
        if not hasattr(llm, "model_name") or not llm.model_name.startswith("gpt-4"):
            logger.warning(
                "This chain was only tested with GPT-4. "
                "Performance may be significantly worse with other models."
            )
expected_input_vars = {"prediction", "prediction_b", "input", "criteria"}
prompt_ = prompt or COMPARISON_TEMPLATE.partial(reference="")
if expected_input_vars != set(prompt_.input_variables):
raise ValueError(
f"Input variables should be {expected_input_vars}, "
f"but got {prompt_.input_variables}"
)
criteria_ = resolve_pairwise_criteria(criteria)
criteria_str = "\n".join(f"{k}: {v}" if v else k for k, v in criteria_.items())
criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)
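    # A hedged usage sketch for from_llm (hypothetical custom criterion; any
    # BaseLanguageModel is accepted, though only GPT-4 was tested):
    #
    #   from langchain_community.chat_models import ChatOpenAI
    #   chain = PairwiseStringEvalChain.from_llm(
    #       llm=ChatOpenAI(temperature=0, model_name="gpt-4"),
    #       criteria={"clarity": "Is the response clear and unambiguous?"},
    #   )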
def _prepare_input(
self,
prediction: str,
prediction_b: str,
input: Optional[str],
reference: Optional[str],
) -> dict:
"""准备链条的输入。
参数:
prediction (str): 第一个模型的输出字符串。
prediction_b (str): 第二个模型的输出字符串。
input (str, optional): 输入或任务字符串。
reference (str, optional): 参考字符串,如果有的话。
返回:
dict: 为链条准备的输入。
"""
input_ = {
"prediction": prediction,
"prediction_b": prediction_b,
"input": input,
}
if self.requires_reference:
input_["reference"] = reference
return input_
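    # For example (hypothetical values), _prepare_input("a", "b", "task", None)
    # on this chain returns {"prediction": "a", "prediction_b": "b",
    # "input": "task"}; the "reference" key is only added when
    # requires_reference is True.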
def _prepare_output(self, result: dict) -> dict:
"""准备输出。"""
parsed = result[self.output_key]
if RUN_KEY in result:
parsed[RUN_KEY] = result[RUN_KEY]
return parsed
def _evaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
input: Optional[str] = None,
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""评估输出A是否优于输出B。
参数:
prediction (str): 第一个模型的输出字符串。
prediction_b (str): 第二个模型的输出字符串。
input (str, optional): 输入或任务字符串。
callbacks (Callbacks, optional): 要使用的回调函数。
reference (str, optional): 参考字符串,如果有的话。
**kwargs (Any): 其他关键字参数。
返回:
dict: 包含以下内容的字典:
- reasoning: 偏好的原因。
- value: 偏好值,可以是'A'、'B',或者没有偏好时为None。
- score: 偏好分数,为1表示'A',为0表示'B',为0.5表示None。
"""
input_ = self._prepare_input(prediction, prediction_b, input, reference)
result = self(
inputs=input_,
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)
async def _aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
"""异步评估输出A是否优于输出B。
参数:
prediction (str): 第一个模型的输出字符串。
prediction_b (str): 第二个模型的输出字符串。
input (str, optional): 输入或任务字符串。
callbacks (Callbacks, optional): 要使用的回调。
reference (str, optional): 参考字符串(如果有)。
**kwargs (Any): 其他关键字参数。
返回:
dict: 包含以下内容的字典:
- reasoning: 偏好的理由。
- value: 偏好值,可以是'A'、'B',或者没有偏好时为None。
- score: 偏好分数,'A'为1,'B'为0,None为0.5。
"""
input_ = self._prepare_input(prediction, prediction_b, input, reference)
result = await self.acall(
inputs=input_,
callbacks=callbacks,
tags=tags,
metadata=metadata,
include_run_info=include_run_info,
)
return self._prepare_output(result)
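    # An async usage sketch (hypothetical; assumes an OpenAI key is configured
    # and a `chain` built as in the from_llm sketch above, called through the
    # public aevaluate_string_pairs wrapper):
    #
    #   import asyncio
    #   result = asyncio.run(
    #       chain.aevaluate_string_pairs(
    #           prediction="H2O",
    #           prediction_b="The chemical formula for water is H2O.",
    #           input="What is the chemical formula for water?",
    #       )
    #   )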
class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
    """A chain for comparing two outputs, such as the outputs
    of two models, prompts, or outputs of a single model on similar inputs,
    with labeled preferences.
    Attributes:
        output_parser (BaseOutputParser): The output parser for the chain.
    """
@property
def requires_reference(self) -> bool:
"""返回链是否需要引用。
返回:
布尔值:如果链需要引用,则为True,否则为False。
"""
return True
    @classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
prompt: Optional[PromptTemplate] = None,
criteria: Optional[Union[CRITERIA_TYPE, str]] = None,
**kwargs: Any,
    ) -> LabeledPairwiseStringEvalChain:
"""从LLM初始化LabeledPairwiseStringEvalChain。
参数:
llm (BaseLanguageModel): 要使用的LLM。
prompt (PromptTemplate, 可选): 要使用的提示。
criteria (Union[CRITERIA_TYPE, str], 可选): 要使用的标准。
**kwargs (Any): 额外的关键字参数。
返回:
LabeledPairwiseStringEvalChain: 初始化的LabeledPairwiseStringEvalChain。
引发:
ValueError: 如果输入变量不符合预期。
""" # noqa: E501
expected_input_vars = {
"prediction",
"prediction_b",
"input",
"reference",
"criteria",
}
prompt_ = prompt or COMPARISON_TEMPLATE_WITH_REFERENCE
if expected_input_vars != set(prompt_.input_variables):
raise ValueError(
f"Input variables should be {expected_input_vars}, "
f"but got {prompt_.input_variables}"
)
criteria_ = resolve_pairwise_criteria(criteria)
criteria_str = "\n".join(f"{k}: {v}" for k, v in criteria_.items())
criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)
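    # A usage sketch with a reference label (hypothetical; assumes an OpenAI key):
    #
    #   from langchain_community.chat_models import ChatOpenAI
    #   chain = LabeledPairwiseStringEvalChain.from_llm(
    #       llm=ChatOpenAI(temperature=0, model_name="gpt-4")
    #   )
    #   result = chain.evaluate_string_pairs(
    #       prediction="H2O",
    #       prediction_b="Water is H2O.",
    #       input="What is the chemical formula for water?",
    #       reference="The chemical formula for water is H2O.",
    #   )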