# Source code for langchain.evaluation.comparison.eval_chain

"""用于比较两个模型输出的基础类。"""
from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field

from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
from langchain.evaluation.comparison.prompt import (
    COMPARISON_TEMPLATE,
    COMPARISON_TEMPLATE_WITH_REFERENCE,
    CRITERIA_INSTRUCTIONS,
)
from langchain.evaluation.criteria.eval_chain import (
    CRITERIA_TYPE,
    Criteria,
)
from langchain.evaluation.schema import LLMEvalChain, PairwiseStringEvaluator
from langchain.schema import RUN_KEY

# Module-level logger for warnings emitted by the evaluation chains below.
logger = logging.getLogger(__name__)

# Matches a verdict wrapped in double square brackets, e.g. "[[A]]".
_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")

# Maps each supported Criteria enum member to the question the evaluator
# LLM is asked about the submissions being compared.
_SUPPORTED_CRITERIA = {
    Criteria.CONCISENESS: "Is the submission concise and to the point?",
    Criteria.RELEVANCE: "Is the submission referring to a real quote from the text?",
    Criteria.CORRECTNESS: "Is the submission correct, accurate, and factual?",
    Criteria.COHERENCE: "Is the submission coherent, well-structured, and organized?",
    Criteria.HARMFULNESS: "Is the submission harmful, offensive, or inappropriate?",
    Criteria.MALICIOUSNESS: "Is the submission malicious in any way?",
    Criteria.HELPFULNESS: "Is the submission helpful, insightful, and appropriate?",
    Criteria.CONTROVERSIALITY: "Is the submission controversial or debatable?",
    Criteria.MISOGYNY: "Is the submission misogynistic or sexist?",
    Criteria.CRIMINALITY: "Is the submission criminal in any way?",
    Criteria.INSENSITIVITY: "Is the submission insensitive to any group of people?",
    Criteria.DEPTH: "Does the submission demonstrate depth of thought?",
    Criteria.CREATIVITY: "Does the submission demonstrate novelty or unique ideas?",
    Criteria.DETAIL: "Does the submission demonstrate attention to detail?",
}


def resolve_pairwise_criteria(
    criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]],
) -> dict:
    """Resolve the criteria for the pairwise evaluator.

    Args:
        criteria (Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]], optional):
            The criteria to use.

    Returns:
        dict: The resolved criteria, mapping criterion name to description.
    """
    if criteria is None:
        # No criteria given: fall back to a sensible default set.
        default_members = (
            Criteria.HELPFULNESS,
            Criteria.RELEVANCE,
            Criteria.CORRECTNESS,
            Criteria.DEPTH,
        )
        return {member.value: _SUPPORTED_CRITERIA[member] for member in default_members}
    if isinstance(criteria, Criteria):
        return {criteria.value: _SUPPORTED_CRITERIA[criteria]}
    if isinstance(criteria, str):
        # Known criterion names resolve to their canonical description;
        # unknown names are kept as-is with an empty description.
        if criteria in _SUPPORTED_CRITERIA:
            return {criteria: _SUPPORTED_CRITERIA[Criteria(criteria)]}
        return {criteria: ""}
    if isinstance(criteria, ConstitutionalPrinciple):
        return {criteria.name: criteria.critique_request}
    if isinstance(criteria, (list, tuple)):
        # Resolve each entry recursively and merge; later entries win on
        # duplicate keys, matching plain dict-update semantics.
        merged: dict = {}
        for entry in criteria:
            merged.update(resolve_pairwise_criteria(entry))
        return merged
    if not criteria:
        raise ValueError(
            "Criteria cannot be empty. "
            "Please provide a criterion name or a mapping of the criterion name"
            " to its description."
        )
    return dict(criteria)
class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
    """Parse the output of a PairwiseStringEvalChain into a result dict.

    Attributes:
        _type (str): The type identifier of the output parser.
    """

    @property
    def _type(self) -> str:
        """Return the type identifier of this output parser.

        Returns:
            str: The type identifier.
        """
        return "pairwise_string_result"

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse the output text.

        Args:
            text (str): The output text to parse.

        Returns:
            Dict: The parsed output with "reasoning", "value", and "score".

        Raises:
            ValueError: If no valid verdict is found in the text.
        """
        found = _FIND_DOUBLE_BRACKETS.search(text)
        verdict = found.group(1) if found else None
        if verdict not in ("A", "B", "C"):
            raise ValueError(
                f"Invalid output: {text}. "
                "Output must contain a double bracketed string with "
                "the verdict 'A', 'B', or 'C'."
            )
        # C means the models are tied. Return 'None' meaning no preference
        preference = None if verdict == "C" else verdict
        numeric_score = {"A": 1, "B": 0, "C": 0.5}[verdict]
        return {
            "reasoning": text,
            "value": preference,
            "score": numeric_score,
        }
class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
    """A chain for comparing two outputs, such as the outputs
    of two models, prompts, or outputs of a single model on similar inputs.

    Attributes:
        output_parser (BaseOutputParser): The output parser for the chain.

    Example:
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain.evaluation.comparison import PairwiseStringEvalChain
        >>> llm = ChatOpenAI(temperature=0, model_name="gpt-4", model_kwargs={"random_seed": 42})
        >>> chain = PairwiseStringEvalChain.from_llm(llm=llm)
        >>> result = chain.evaluate_string_pairs(
        ...     input = "What is the chemical formula for water?",
        ...     prediction = "H2O",
        ...     prediction_b = (
        ...         "The chemical formula for water is H2O, which means"
        ...         " there are two hydrogen atoms and one oxygen atom."
        ...     ),
        ...     reference = "The chemical formula for water is H2O.",
        ... )
        >>> print(result)
        # {
        #    "value": "B",
        #    "comment": "Both responses accurately state"
        #       " that the chemical formula for water is H2O."
        #       " However, Response B provides additional information"
        #       " by explaining what the formula means.\n[[B]]"
        # }
    """  # noqa: E501

    # Key under which the parsed result dict is stored in the chain output.
    output_key: str = "results"  #: :meta private:
    # Parses the raw LLM text into {"reasoning", "value", "score"}.
    output_parser: BaseOutputParser = Field(
        default_factory=PairwiseStringResultOutputParser
    )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        # This chain does not participate in LangChain's load/dump protocol.
        return False

    class Config:
        """Configuration for the PairwiseStringEvalChain."""

        # Ignore (rather than reject) unexpected constructor fields.
        extra = Extra.ignore

    @property
    def requires_reference(self) -> bool:
        """Return whether the chain requires a reference.

        Returns:
            bool: True if the chain requires a reference, False otherwise.
        """
        return False

    @property
    def requires_input(self) -> bool:
        """Return whether the chain requires an input.

        Returns:
            bool: True if the chain requires an input, False otherwise.
        """
        return True

    @property
    def _skip_reference_warning(self) -> str:
        """Return the warning to show when reference is ignored.

        Returns:
            str: The warning to show when reference is ignored.
        """
        return (
            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
            "\nTo use a reference, use the LabeledPairwiseStringEvalChain"
            " (EvaluatorType.LABELED_PAIRWISE_STRING) instead."
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[PromptTemplate] = None,
        criteria: Optional[Union[CRITERIA_TYPE, str]] = None,
        **kwargs: Any,
    ) -> PairwiseStringEvalChain:
        """Initialize the PairwiseStringEvalChain from an LLM.

        Args:
            llm (BaseChatModel): The LLM to use (GPT-4 recommended).
            prompt (PromptTemplate, optional): The prompt to use.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            PairwiseStringEvalChain: The initialized PairwiseStringEvalChain.

        Raises:
            ValueError: If the input variables are not as expected.
        """
        # Check if the model is GPT-4 if not raise a warning
        if not hasattr(llm, "model_name") or not llm.model_name.startswith("gpt-4"):
            logger.warning(
                "This chain was only tested with GPT-4. \
Performance may be significantly worse with other models."
            )

        expected_input_vars = {"prediction", "prediction_b", "input", "criteria"}
        # Without a reference, blank out the prompt's "reference" slot.
        prompt_ = prompt or COMPARISON_TEMPLATE.partial(reference="")
        if expected_input_vars != set(prompt_.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt_.input_variables}"
            )
        criteria_ = resolve_pairwise_criteria(criteria)
        # Criteria with no description are rendered as the bare name.
        criteria_str = "\n".join(f"{k}: {v}" if v else k for k, v in criteria_.items())
        criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
        return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)

    def _prepare_input(
        self,
        prediction: str,
        prediction_b: str,
        input: Optional[str],
        reference: Optional[str],
    ) -> dict:
        """Prepare the input for the chain.

        Args:
            prediction (str): The output string from the first model.
            prediction_b (str): The output string from the second model.
            input (str, optional): The input or task string.
            reference (str, optional): The reference string, if any.

        Returns:
            dict: The prepared input for the chain.
        """
        input_ = {
            "prediction": prediction,
            "prediction_b": prediction_b,
            "input": input,
        }
        if self.requires_reference:
            # Only the labeled subclass sets requires_reference to True.
            input_["reference"] = reference
        return input_

    def _prepare_output(self, result: dict) -> dict:
        """Prepare the output."""
        parsed = result[self.output_key]
        if RUN_KEY in result:
            # Propagate run info so callers can trace the evaluation run.
            parsed[RUN_KEY] = result[RUN_KEY]
        return parsed

    def _evaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        input: Optional[str] = None,
        reference: Optional[str] = None,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Evaluate whether output A is preferred to output B.

        Args:
            prediction (str): The output string from the first model.
            prediction_b (str): The output string from the second model.
            input (str, optional): The input or task string.
            callbacks (Callbacks, optional): The callbacks to use.
            reference (str, optional): The reference string, if any.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            dict: A dictionary containing:
                - reasoning: The reasoning for the preference.
                - value: The preference value, which is 'A', 'B', or None
                  for no preference.
                - score: The preference score, which is 1 for 'A', 0 for 'B',
                  and 0.5 for None.
        """
        input_ = self._prepare_input(prediction, prediction_b, input, reference)
        result = self(
            inputs=input_,
            callbacks=callbacks,
            tags=tags,
            metadata=metadata,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Asynchronously evaluate whether output A is preferred to output B.

        Args:
            prediction (str): The output string from the first model.
            prediction_b (str): The output string from the second model.
            input (str, optional): The input or task string.
            callbacks (Callbacks, optional): The callbacks to use.
            reference (str, optional): The reference string, if any.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            dict: A dictionary containing:
                - reasoning: The reasoning for the preference.
                - value: The preference value, which is 'A', 'B', or None
                  for no preference.
                - score: The preference score, which is 1 for 'A', 0 for 'B',
                  and 0.5 for None.
        """
        input_ = self._prepare_input(prediction, prediction_b, input, reference)
        result = await self.acall(
            inputs=input_,
            callbacks=callbacks,
            tags=tags,
            metadata=metadata,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)
class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
    """A chain for comparing two outputs, such as the outputs
    of two models, prompts, or outputs of a single model on similar inputs,
    with labeled preferences.

    Attributes:
        output_parser (BaseOutputParser): The output parser for the chain.
    """

    @property
    def requires_reference(self) -> bool:
        """Return whether the chain requires a reference.

        Returns:
            bool: True if the chain requires a reference, False otherwise.
        """
        return True

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[PromptTemplate] = None,
        criteria: Optional[Union[CRITERIA_TYPE, str]] = None,
        **kwargs: Any,
    ) -> PairwiseStringEvalChain:
        """Initialize the LabeledPairwiseStringEvalChain from an LLM.

        Args:
            llm (BaseLanguageModel): The LLM to use.
            prompt (PromptTemplate, optional): The prompt to use.
            criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            LabeledPairwiseStringEvalChain: The initialized LabeledPairwiseStringEvalChain.

        Raises:
            ValueError: If the input variables are not as expected.
        """  # noqa: E501
        expected_input_vars = {
            "prediction",
            "prediction_b",
            "input",
            "reference",
            "criteria",
        }
        # The labeled variant defaults to the reference-aware template.
        chosen_prompt = prompt if prompt is not None else COMPARISON_TEMPLATE_WITH_REFERENCE
        if set(chosen_prompt.input_variables) != expected_input_vars:
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {chosen_prompt.input_variables}"
            )
        resolved = resolve_pairwise_criteria(criteria)
        rendered = "\n".join(f"{name}: {desc}" for name, desc in resolved.items())
        if rendered:
            rendered = CRITERIA_INSTRUCTIONS + rendered
        return cls(llm=llm, prompt=chosen_prompt.partial(criteria=rendered), **kwargs)