Source code for langchain_community.example_selectors.ngram_overlap

"""根据ngram重叠得分(来自NLTK包中的sentence_bleu得分)选择和排序示例。
"""
from typing import Dict, List

import numpy as np
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator


def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Score the n-gram overlap between a source text and example texts.

    The score is the sentence-level BLEU score from the NLTK package,
    computed with the ``method1`` smoothing function and automatic
    re-weighting enabled.

    Only the first entry of ``source`` is scored; it is used as the
    hypothesis. Every entry of ``example`` is used as a reference. All texts
    are tokenized by splitting on whitespace.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf

    Args:
        source: Texts to score; only ``source[0]`` is read.
        example: Reference texts to compare against.

    Returns:
        The ``sentence_bleu`` result converted to ``float`` (a value between
        0.0 and 1.0 per the NLTK documentation).
    """
    from nltk.translate.bleu_score import (
        SmoothingFunction,  # type: ignore
        sentence_bleu,
    )

    candidate_tokens = source[0].split()

    reference_tokens = []
    for reference_text in example:
        reference_tokens.append(reference_text.split())

    smoother = SmoothingFunction().method1
    bleu = sentence_bleu(
        reference_tokens,
        candidate_tokens,
        smoothing_function=smoother,
        auto_reweigh=True,
    )
    return float(bleu)
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples by their n-gram overlap with the input.

    Overlap is measured with ``ngram_overlap_score`` (the sentence_bleu score
    from the NLTK package).
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which the algorithm stops. Set to -1.0 by default.

    For a negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For a threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For a threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with the input.
    """

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that the NLTK dependency is importable.

        Raises:
            ImportError: If ``nltk.translate.bleu_score`` cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            raise ImportError(
                "Not all the correct dependencies for this ExampleSelect exist. "
                "Please install nltk with `pip install nltk`."
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Append a new example to the list of examples."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return the examples ordered by ngram_overlap_score with the input.

        Examples are returned in descending score order; examples with equal
        scores keep their relative order from ``self.examples``. Any example
        whose score is less than or equal to the threshold (within a 1e-9
        tolerance) is excluded. An empty ``self.examples`` yields an empty
        list.

        Args:
            input_variables: Input values; they are passed to
                ``ngram_overlap_score`` as the source texts.

        Returns:
            The selected examples, best match first.
        """
        inputs = list(input_variables.values())
        # Each example is compared through the value stored under the first
        # input variable of the example prompt.
        first_prompt_template_key = self.example_prompt.input_variables[0]

        scores = [
            ngram_overlap_score(inputs, [example[first_prompt_template_key]])
            for example in self.examples
        ]

        # sorted() is stable even with reverse=True, so ties stay in index
        # order. A single sort also handles an empty example list and always
        # terminates, whatever the threshold value is.
        ranked = sorted(
            range(len(scores)), key=scores.__getitem__, reverse=True
        )

        selected = []
        for index in ranked:
            score = scores[index]
            # Scores are descending: once one is at or below the threshold,
            # every remaining one is as well.
            if score < self.threshold or abs(score - self.threshold) < 1e-9:
                break
            selected.append(self.examples[index])

        return selected