Source code for langchain.retrievers.multi_query

import asyncio
import logging
from typing import List, Optional, Sequence

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable

from langchain.chains.llm import LLMChain

logger = logging.getLogger(__name__)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return lines
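

# A minimal illustration (not part of the library source): the parser splits the raw
# LLM text on newlines, so a three-line completion becomes three candidate queries.
# The parser instance and sample text below are assumptions for demonstration only.
_example_parser = LineListOutputParser()
assert _example_parser.parse("query one\nquery two\nquery three") == [
    "query one",
    "query two",
    "query three",
]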


# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is
    to generate 3 different versions of the given user question to retrieve
    relevant documents from a vector database. By generating multiple
    perspectives on the user question, your goal is to help the user overcome
    some of the limitations of distance-based similarity search. Provide these
    alternative questions separated by newlines. Original question: {question}""",
)


def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
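

# A small sketch of the dedup helper above (illustrative only): Document equality is
# field-based, so the same content retrieved for two different generated queries
# collapses to a single entry. The sample documents are assumptions for demonstration.
_doc_a = Document(page_content="alpha")
_doc_b = Document(page_content="beta")
assert _unique_documents([_doc_a, _doc_b, Document(page_content="alpha")]) == [
    _doc_a,
    _doc_b,
]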


class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    retriever: BaseRetriever
    llm_chain: Runnable
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an llm using the default template.

        Args:
            retriever: retriever to query documents from
            llm: llm for query generation using DEFAULT_QUERY_PROMPT
            include_original: Whether to include the original query in the list of
                generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = prompt | llm | output_parser
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )
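
    # A minimal usage sketch for from_llm (illustrative comments only; ChatOpenAI,
    # OpenAIEmbeddings, and Chroma are assumed stand-ins, and any chat model plus any
    # BaseRetriever works the same way):
    #
    #     from langchain_chroma import Chroma
    #     from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    #
    #     vectorstore = Chroma.from_texts(
    #         ["harrison worked at kensho"], OpenAIEmbeddings()
    #     )
    #     retriever = MultiQueryRetriever.from_llm(
    #         retriever=vectorstore.as_retriever(), llm=ChatOpenAI(temperature=0)
    #     )
    #     docs = retriever.invoke("Where did Harrison work?")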

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.ainvoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        if isinstance(self.llm_chain, LLMChain):
            lines = response["text"]
        else:
            lines = response
        if self.verbose:
            logger.info(f"Generated queries: {lines}")
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries.

        Args:
            queries: query list

        Returns:
            List of retrieved Documents
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.ainvoke(
                    query, config={"callbacks": run_manager.get_child()}
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain.invoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        if isinstance(self.llm_chain, LLMChain):
            lines = response["text"]
        else:
            lines = response
        if self.verbose:
            logger.info(f"Generated queries: {lines}")
        return lines
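
    # Note on the branch above (illustrative): a legacy LLMChain returns a dict keyed
    # by "text", whereas the LCEL pipeline built by from_llm
    # (prompt | llm | LineListOutputParser()) returns the parsed list of lines
    # directly, which is why both response shapes are handled here and in
    # agenerate_queries.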

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries.

        Args:
            queries: query list

        Returns:
            List of retrieved Documents
        """
        documents = []
        for query in queries:
            docs = self.retriever.invoke(
                query, config={"callbacks": run_manager.get_child()}
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents
        """
        return _unique_documents(documents)
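

# End-to-end sketch (illustrative only, not part of the module): a toy retriever and
# an offline fake LLM stand in for a real vector store retriever and chat model.
# FakeListLLM is assumed to be available from langchain_community; _ToyRetriever and
# its sample document are hypothetical.
from langchain_community.llms import FakeListLLM


class _ToyRetriever(BaseRetriever):
    """Toy retriever returning the same document for every query."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content="harrison worked at kensho")]


_example_retriever = MultiQueryRetriever.from_llm(
    retriever=_ToyRetriever(),
    llm=FakeListLLM(
        responses=[
            "Where did Harrison work?\nWho employs Harrison?\nHarrison's workplace?"
        ]
    ),
)
# Each generated query hits the toy retriever; the duplicates collapse in unique_union,
# so a single Document comes back. The async path (ainvoke) behaves the same, fanning
# the per-query retrievals out concurrently via asyncio.gather.
_docs = _example_retriever.invoke("where did harrison work")
assert _docs == [Document(page_content="harrison worked at kensho")]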