# Source code for langchain.evaluation.qa.eval_chain

"""用于评估问答的LLM链。"""
from __future__ import annotations

import re
import string
from typing import Any, List, Optional, Sequence, Tuple

from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra

from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
from langchain.schema import RUN_KEY


def _get_score(text: str) -> Optional[Tuple[str, int]]:
    match = re.search(r"grade:\s*(correct|incorrect)", text.strip(), re.IGNORECASE)
    if match:
        if match.group(1).upper() == "CORRECT":
            return "CORRECT", 1
        elif match.group(1).upper() == "INCORRECT":
            return "INCORRECT", 0
    try:
        first_word = (
            text.strip().split()[0].translate(str.maketrans("", "", string.punctuation))
        )
        if first_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif first_word.upper() == "INCORRECT":
            return "INCORRECT", 0
        last_word = (
            text.strip()
            .split()[-1]
            .translate(str.maketrans("", "", string.punctuation))
        )
        if last_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif last_word.upper() == "INCORRECT":
            return "INCORRECT", 0
    except IndexError:
        pass
    return None


def _parse_string_eval_output(text: str) -> dict:
    """Parse a grader's raw output text into an evaluation result.

    Args:
        text: The raw output text to parse.

    Returns:
        dict: A mapping with three keys:

        - ``"reasoning"``: the input text with surrounding whitespace removed.
        - ``"value"``: ``"CORRECT"``, ``"INCORRECT"``, or ``None`` if no
          verdict could be found.
        - ``"score"``: ``1``, ``0``, or ``None`` (mirrors ``value``).
    """
    reasoning = text.strip()
    # _get_score returns a non-empty tuple or None, so `or` supplies the
    # (None, None) fallback only when no verdict was parsed.
    value, score = _get_score(reasoning) or (None, None)
    return {"reasoning": reasoning, "value": value, "score": score}


class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating question answering against a reference answer."""

    output_key: str = "results"  #: :meta private:

    class Config:
        """Configuration for the QAEvalChain."""

        extra = Extra.ignore

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``False``: this chain is not marked LangChain-serializable."""
        return False

    @property
    def evaluation_name(self) -> str:
        """Return the evaluation name, ``"correctness"``."""
        return "correctness"

    @property
    def requires_reference(self) -> bool:
        """Whether the chain requires a reference (answer) string."""
        return True

    @property
    def requires_input(self) -> bool:
        """Whether the chain requires an input (question) string."""
        return True

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> QAEvalChain:
        """Load a QA evaluation chain from an LLM.

        Args:
            llm: The base language model to use.
            prompt: A prompt template whose input variables are exactly
                ``'query'``, ``'answer'`` and ``'result'``; it is used as the
                evaluation prompt. Defaults to ``PROMPT``.
            **kwargs: Additional keyword arguments forwarded to the
                constructor.

        Returns:
            QAEvalChain: The loaded QA evaluation chain.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``{'query', 'answer', 'result'}``.
        """
        prompt = prompt or PROMPT
        expected_input_vars = {"query", "answer", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )
        return cls(llm=llm, prompt=prompt, **kwargs)

    def evaluate(
        self,
        examples: Sequence[dict],
        predictions: Sequence[dict],
        question_key: str = "query",
        answer_key: str = "answer",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions.

        Args:
            examples: Examples, each holding a question under
                ``question_key`` and a reference answer under ``answer_key``.
            predictions: Predictions matched to ``examples`` by position,
                each holding the predicted answer under ``prediction_key``.
            question_key: Key of the question in each example.
            answer_key: Key of the reference answer in each example.
            prediction_key: Key of the predicted answer in each prediction.
            callbacks: Callbacks passed through to ``self.apply``.

        Returns:
            List[dict]: The raw outputs of ``self.apply``, one per example.
        """
        # Predictions are paired with examples by index: surplus predictions
        # are ignored, and too few predictions raise IndexError.
        inputs = [
            {
                "query": example[question_key],
                "answer": example[answer_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]
        return self.apply(inputs, callbacks=callbacks)

    def _prepare_output(self, result: dict) -> dict:
        """Parse the chain's raw output and carry over run info if present."""
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Evaluate a chain or LLM output, given an optional input and label.

        Args:
            prediction: The LLM or chain prediction to evaluate.
            reference: The reference answer to evaluate against.
            input: The question that was asked.
            callbacks: Callbacks to use for tracing.
            include_run_info: Whether to include run info in the result.
            **kwargs: Additional keyword arguments (not used here).

        Returns:
            dict: The parsed result with ``reasoning``, ``value`` and
            ``score`` keys, plus run info when requested.
        """
        result = self(
            {
                "query": input,
                "answer": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Asynchronous counterpart of ``_evaluate_strings``.

        Takes the same arguments and returns the same shape of result, but
        runs the chain with ``self.acall``.
        """
        result = await self.acall(
            inputs={"query": input, "answer": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)
class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating QA without ground truth, based on context."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``False``: this chain is not marked LangChain-serializable."""
        return False

    @property
    def requires_reference(self) -> bool:
        """Whether the chain requires a reference string (the context)."""
        return True

    @property
    def requires_input(self) -> bool:
        """Whether the chain requires an input (question) string."""
        return True

    class Config:
        """Configuration for the ContextQAEvalChain."""

        extra = Extra.ignore

    @classmethod
    def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
        """Check that the prompt's input variables are query/context/result.

        Raises:
            ValueError: If the input variables are not exactly
                ``{'query', 'context', 'result'}``.
        """
        expected_input_vars = {"query", "context", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )

    @property
    def evaluation_name(self) -> str:
        """Return the evaluation name, ``"Contextual Accuracy"``."""
        return "Contextual Accuracy"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> ContextQAEvalChain:
        """Load a context-based QA evaluation chain from an LLM.

        Args:
            llm: The base language model to use.
            prompt: A prompt template whose input variables are exactly
                ``'query'``, ``'context'`` and ``'result'``; it is used as the
                evaluation prompt. Defaults to ``CONTEXT_PROMPT``.
            **kwargs: Additional keyword arguments forwarded to the
                constructor.

        Returns:
            ContextQAEvalChain: The loaded QA evaluation chain.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``{'query', 'context', 'result'}``.
        """
        prompt = prompt or CONTEXT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)

    def evaluate(
        self,
        examples: Sequence[dict],
        predictions: Sequence[dict],
        question_key: str = "query",
        context_key: str = "context",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions.

        Args:
            examples: Examples, each holding a question under
                ``question_key`` and its context under ``context_key``.
            predictions: Predictions matched to ``examples`` by position,
                each holding the predicted answer under ``prediction_key``.
            question_key: Key of the question in each example.
            context_key: Key of the context in each example.
            prediction_key: Key of the predicted answer in each prediction.
            callbacks: Callbacks passed through to ``self.apply``.

        Returns:
            List[dict]: The raw outputs of ``self.apply``, one per example.
        """
        # Predictions are paired with examples by index: surplus predictions
        # are ignored, and too few predictions raise IndexError.
        inputs = [
            {
                "query": example[question_key],
                "context": example[context_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]
        return self.apply(inputs, callbacks=callbacks)

    def _prepare_output(self, result: dict) -> dict:
        """Parse the chain's raw output and carry over run info if present."""
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Evaluate a prediction against a question and its context.

        Args:
            prediction: The LLM or chain prediction to evaluate.
            reference: The context, passed to the prompt as ``context``.
            input: The question that was asked.
            callbacks: Callbacks to use for tracing.
            include_run_info: Whether to include run info in the result.
            **kwargs: Additional keyword arguments (not used here).

        Returns:
            dict: The parsed result with ``reasoning``, ``value`` and
            ``score`` keys, plus run info when requested.
        """
        result = self(
            {
                "query": input,
                "context": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Asynchronous counterpart of ``_evaluate_strings``.

        Takes the same arguments and returns the same shape of result, but
        runs the chain with ``self.acall``.
        """
        result = await self.acall(
            inputs={"query": input, "context": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)
class CotQAEvalChain(ContextQAEvalChain):
    """LLM Chain for evaluating QA using chain-of-thought reasoning."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``False``: this chain is not marked LangChain-serializable."""
        return False

    @property
    def evaluation_name(self) -> str:
        """Return the evaluation name, ``"COT Contextual Accuracy"``."""
        return "COT Contextual Accuracy"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> CotQAEvalChain:
        """Load a chain-of-thought QA evaluation chain from an LLM.

        Args:
            llm: The base language model to use.
            prompt: A prompt template whose input variables are exactly
                ``'query'``, ``'context'`` and ``'result'``. Defaults to
                ``COT_PROMPT``.
            **kwargs: Additional keyword arguments forwarded to the
                constructor.

        Returns:
            CotQAEvalChain: The loaded QA evaluation chain.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``{'query', 'context', 'result'}`` (checked by
                ``_validate_input_vars``).
        """
        prompt = prompt or COT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)