"""Source code for ``langchain.chains.qa_generation.base``."""

from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR


class QAGenerationChain(Chain):
    """Base class for question-answer generation chains.

    The input text is split into chunks and, for each chunk, the wrapped
    LLM chain is asked to produce question/answer pairs encoded as JSON.
    """

    llm_chain: LLMChain
    """LLM chain used to generate responses from the user input and context."""
    text_splitter: TextSplitter = Field(
        default=RecursiveCharacterTextSplitter(chunk_overlap=500)
    )
    """Text splitter that breaks the input into chunks."""
    input_key: str = "text"
    """Key under which the input text is passed to the chain."""
    output_key: str = "questions"
    """Key under which the chain returns its output."""
    k: Optional[int] = None
    """Number of questions to generate."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """Create a QAGenerationChain from a language model.

        Args:
            llm: a language model
            prompt: a prompt template; when omitted (or falsy), one is
                selected for ``llm`` via ``PROMPT_SELECTOR``
            **kwargs: additional arguments forwarded to the constructor

        Returns:
            a QAGenerationChain instance
        """
        selected_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        return cls(
            llm_chain=LLMChain(llm=llm, prompt=selected_prompt),
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        # No serialization identifier is defined for this chain.
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        """Single expected input key (``input_key``)."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Single produced output key (``output_key``)."""
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List]:
        # Chunk the raw input text, then generate once per chunk.
        chunks = self.text_splitter.create_documents([inputs[self.input_key]])
        chunk_payloads = [{"text": chunk.page_content} for chunk in chunks]
        results = self.llm_chain.generate(chunk_payloads, run_manager=run_manager)
        # Each chunk's first generation is expected to be a JSON document
        # of question/answer pairs.
        parsed = []
        for candidates in results.generations:
            parsed.append(json.loads(candidates[0].text))
        return {self.output_key: parsed}