"""用于评估问答的LLM链。"""
from __future__ import annotations
import re
import string
from typing import Any, List, Optional, Sequence, Tuple
from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
from langchain.schema import RUN_KEY


def _get_score(text: str) -> Optional[Tuple[str, int]]:
    """Extract a (verdict, score) pair from the grader's raw output, if any."""
    # Prefer an explicit "GRADE: CORRECT" / "GRADE: INCORRECT" verdict
    # anywhere in the text.
    match = re.search(r"grade:\s*(correct|incorrect)", text.strip(), re.IGNORECASE)
    if match:
        if match.group(1).upper() == "CORRECT":
            return "CORRECT", 1
        elif match.group(1).upper() == "INCORRECT":
            return "INCORRECT", 0
    # Otherwise fall back to the first and last words, stripped of punctuation.
    try:
        first_word = (
            text.strip().split()[0].translate(str.maketrans("", "", string.punctuation))
        )
        if first_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif first_word.upper() == "INCORRECT":
            return "INCORRECT", 0
        last_word = (
            text.strip()
            .split()[-1]
            .translate(str.maketrans("", "", string.punctuation))
        )
        if last_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif last_word.upper() == "INCORRECT":
            return "INCORRECT", 0
    except IndexError:
        pass
    return None
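
# Illustration of how ``_get_score`` resolves verdicts (doctest-style comments,
# traced against the logic above; not part of the module's behavior):
#
#     >>> _get_score("GRADE: CORRECT")
#     ('CORRECT', 1)
#     >>> _get_score("The answer misses the point. INCORRECT")
#     ('INCORRECT', 0)
#     >>> _get_score("No verdict given.") is None
#     True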


def _parse_string_eval_output(text: str) -> dict:
    """Parse the output text.

    Args:
        text (str): The output text to parse.

    Returns:
        dict: The parsed output with "reasoning", "value", and "score" keys.
    """
    reasoning = text.strip()
    parsed_scores = _get_score(reasoning)
    if parsed_scores is None:
        value, score = None, None
    else:
        value, score = parsed_scores
    return {
        "reasoning": reasoning,
        "value": value,
        "score": score,
    }
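
# Sketch of the resulting shape (illustrative):
#
#     >>> _parse_string_eval_output("GRADE: CORRECT")
#     {'reasoning': 'GRADE: CORRECT', 'value': 'CORRECT', 'score': 1}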


class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating question answering."""

    output_key: str = "results"  #: :meta private:

    class Config:
        """Configuration for the QAEvalChain."""

        extra = Extra.ignore

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def evaluation_name(self) -> str:
        return "correctness"

    @property
    def requires_reference(self) -> bool:
        return True

    @property
    def requires_input(self) -> bool:
        return True
    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> QAEvalChain:
        """Load QA Eval Chain from LLM.

        Args:
            llm (BaseLanguageModel): the base language model to use.
            prompt (PromptTemplate): A prompt template containing the input
                variables 'query', 'answer' and 'result' that will be used as
                the prompt for evaluation. Defaults to PROMPT.
            **kwargs: additional keyword arguments.

        Returns:
            QAEvalChain: the loaded QA eval chain.
        """
        prompt = prompt or PROMPT
        expected_input_vars = {"query", "answer", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )
        return cls(llm=llm, prompt=prompt, **kwargs)
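
    # Example construction (a sketch; ``ChatOpenAI`` and its API-key setup are
    # assumptions, not part of this module):
    #
    #     from langchain_openai import ChatOpenAI
    #
    #     eval_chain = QAEvalChain.from_llm(llm=ChatOpenAI(temperature=0))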

    def evaluate(
        self,
        examples: Sequence[dict],
        predictions: Sequence[dict],
        question_key: str = "query",
        answer_key: str = "answer",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions."""
        inputs = [
            {
                "query": example[question_key],
                "answer": example[answer_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]
        return self.apply(inputs, callbacks=callbacks)
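
    # Sketch of a batch grading call (the exact grade text depends on the LLM):
    #
    #     examples = [{"query": "What is 2 + 2?", "answer": "4"}]
    #     predictions = [{"result": "It is 4."}]
    #     graded = eval_chain.evaluate(examples, predictions)
    #     # -> e.g. [{"results": "GRADE: CORRECT"}]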

    def _prepare_output(self, result: dict) -> dict:
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Evaluate Chain or LLM output, based on optional input and label.

        Args:
            prediction (str): the LLM or chain prediction to evaluate.
            reference (Optional[str], optional): the reference label to
                evaluate against.
            input (Optional[str], optional): the input to consider during
                evaluation.
            callbacks (Callbacks, optional): the callbacks to use for tracing.
            include_run_info (bool, optional): whether to include run info in
                the returned results.
            **kwargs: additional keyword arguments, including callbacks,
                tags, etc.

        Returns:
            dict: The evaluation results containing the score or value.
        """
        result = self(
            {
                "query": input,
                "answer": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)
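
    # Callers usually reach this through the public ``evaluate_strings`` API
    # inherited from ``StringEvaluator`` (a sketch; the reasoning text depends
    # on the LLM):
    #
    #     result = eval_chain.evaluate_strings(
    #         prediction="It is 4.",
    #         reference="4",
    #         input="What is 2 + 2?",
    #     )
    #     # -> {"reasoning": "GRADE: CORRECT", "value": "CORRECT", "score": 1}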

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = await self.acall(
            inputs={"query": input, "answer": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)


class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating QA answers against a context, without ground truth."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def requires_reference(self) -> bool:
        """Whether the chain requires a reference string."""
        return True

    @property
    def requires_input(self) -> bool:
        """Whether the chain requires an input string."""
        return True

    class Config:
        """Configuration for the ContextQAEvalChain."""

        extra = Extra.ignore

    @classmethod
    def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
        expected_input_vars = {"query", "context", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )

    @property
    def evaluation_name(self) -> str:
        return "Contextual Accuracy"
    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> ContextQAEvalChain:
        """Load QA Eval Chain from LLM.

        Args:
            llm (BaseLanguageModel): the base language model to use.
            prompt (PromptTemplate): A prompt template containing the input
                variables 'query', 'context' and 'result' that will be used as
                the prompt for evaluation. Defaults to CONTEXT_PROMPT.
            **kwargs: additional keyword arguments.

        Returns:
            ContextQAEvalChain: the loaded QA eval chain.
        """
        prompt = prompt or CONTEXT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)
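
    # This chain can also be obtained through the generic evaluator loader
    # (a sketch; assumes the "context_qa" registry name in
    # ``langchain.evaluation``):
    #
    #     from langchain.evaluation import load_evaluator
    #
    #     eval_chain = load_evaluator("context_qa", llm=llm)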

    def evaluate(
        self,
        examples: List[dict],
        predictions: List[dict],
        question_key: str = "query",
        context_key: str = "context",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions."""
        inputs = [
            {
                "query": example[question_key],
                "context": example[context_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]
        return self.apply(inputs, callbacks=callbacks)
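
    # Sketch of a batch call graded against context rather than a gold answer:
    #
    #     examples = [
    #         {
    #             "query": "Who wrote the memo?",
    #             "context": "The memo was written by Ana.",
    #         }
    #     ]
    #     predictions = [{"result": "Ana wrote it."}]
    #     graded = eval_chain.evaluate(examples, predictions)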

    def _prepare_output(self, result: dict) -> dict:
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = self(
            {
                "query": input,
                "context": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = await self.acall(
            inputs={"query": input, "context": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)


class CotQAEvalChain(ContextQAEvalChain):
    """LLM Chain for evaluating question answering using chain-of-thought reasoning."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def evaluation_name(self) -> str:
        return "COT Contextual Accuracy"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> CotQAEvalChain:
        """Load QA Eval Chain from LLM. Defaults to COT_PROMPT."""
        prompt = prompt or COT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)
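
# ``CotQAEvalChain`` differs from ``ContextQAEvalChain`` only in its default
# prompt. Sketch of direct use (assumes ``llm`` is any BaseLanguageModel):
#
#     cot_chain = CotQAEvalChain.from_llm(llm=llm)
#     graded = cot_chain.evaluate_strings(
#         prediction="Ana wrote it.",
#         reference="The memo was written by Ana.",
#         input="Who wrote the memo?",
#     )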