Source code for langchain_community.example_selectors.ngram_overlap
"""根据ngram重叠得分(来自NLTK包中的sentence_bleu得分)选择和排序示例。
"""
from typing import Dict, List
import numpy as np
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Compute the ngram overlap score of source and example texts.

    The score is the sentence-level BLEU score from the NLTK package,
    computed with the ``method1`` smoothing function and automatic
    reweighting enabled.

    Only the first element of ``source`` is scored; it is split on
    whitespace and used as the hypothesis. Every string in ``example`` is
    split on whitespace and used as one reference.

    Args:
        source: Input strings; only ``source[0]`` is read.
        example: Reference strings to compare the input against.

    Returns:
        The BLEU score as a float (between 0.0 and 1.0).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Imported lazily so that nltk is only required when scoring is used.
    from nltk.translate.bleu_score import (
        SmoothingFunction,  # type: ignore
        sentence_bleu,
    )

    candidate_tokens = source[0].split()
    reference_token_lists = [text.split() for text in example]
    smoothing = SmoothingFunction().method1

    bleu = sentence_bleu(
        reference_token_lists,
        candidate_tokens,
        smoothing_function=smoothing,
        auto_reweigh=True,
    )
    return float(bleu)
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples by ngram overlap score.

    The score is the ``sentence_bleu`` score from the NLTK package, as
    computed by ``ngram_overlap_score``.
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which the algorithm stops. Set to -1.0 by default.

    For a negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For a threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For a threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with the input."""

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that the nltk dependency can be imported.

        Args:
            values: Raw field values passed to the model; returned unchanged.

        Returns:
            The same ``values`` mapping.

        Raises:
            ImportError: If the nltk BLEU utilities cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            # The trailing space keeps the two sentences from running
            # together once the adjacent literals are concatenated.
            raise ImportError(
                "Not all the correct dependencies for this ExampleSelect exist. "
                "Please install nltk with `pip install nltk`."
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Append a new example to the list of examples."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return the examples sorted by ngram_overlap_score with the input.

        Examples are returned in descending score order; examples with equal
        scores keep their original relative order. Any example whose score is
        less than the threshold, or within 1e-9 of it, is excluded.

        Args:
            input_variables: Input values; they are passed as a list to
                ``ngram_overlap_score``.

        Returns:
            The selected examples, best match first. Empty if there are no
            examples or none scores above the threshold.
        """
        if not self.examples:
            # Nothing to rank; previously this case raised from np.argmax.
            return []

        inputs = list(input_variables.values())
        # Each example is scored on the value of the template's first variable.
        first_prompt_template_key = self.example_prompt.input_variables[0]
        scores = [
            ngram_overlap_score(inputs, [example[first_prompt_template_key]])
            for example in self.examples
        ]

        # A stable sort on the negated scores gives descending order while
        # keeping ties in original order (what repeated argmax produced).
        ranked = np.argsort(-np.asarray(scores, dtype=float), kind="stable")

        selected: List[dict] = []
        for idx in ranked:
            score = scores[idx]
            if score < self.threshold or abs(score - self.threshold) < 1e-9:
                # Scores are descending, so every later one fails as well.
                break
            selected.append(self.examples[idx])
        return selected