# Source code for langchain_community.tools.vectorstore.tool


import json
from typing import Any, Dict, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import VectorStore

from langchain_community.llms.openai import OpenAI

class BaseVectorStoreTool(BaseModel):
    """Base class for tools that use a ``VectorStore``.

    Fields:
        vectorstore: The store queried by subclasses. Declared with
            ``Field(exclude=True)``, so pydantic omits it from exported data.
        llm: Language model handed to the chains that subclasses build. When
            not supplied, ``default_factory`` creates a fresh
            ``OpenAI(temperature=0)`` for each instance.
    """

    # Excluded from pydantic export (dict()/json()).
    vectorstore: VectorStore = Field(exclude=True)
    # Factory (not a shared instance) so every tool gets its own default LLM.
    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))

    class Config(BaseTool.Config):
        """Inherit ``BaseTool``'s pydantic configuration without changes."""
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: values["description"] = values["template"].format(name=values["name"]) return values
class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
    """Tool for the VectorDBQA chain. To be initialized with a name and chain.

    Each call builds a ``RetrievalQA`` chain from ``self.llm`` and
    ``self.vectorstore.as_retriever()``. It then returns the chain's
    ``output_key`` entry for the query.
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Render the description text for a QA tool.

        Args:
            name: Substituted for the ``{name}`` placeholder.
            description: Substituted for the ``{description}`` placeholder.

        Returns:
            The formatted description string.
        """
        fragments = (
            "Useful for when you need to answer questions about {name}. ",
            "Whenever you need information about {description} ",
            "you should ALWAYS use this. ",
            "Input should be a fully formed question.",
        )
        return "".join(fragments).format(name=name, description=description)

    def _build_qa_chain(self) -> Any:
        """Create a new ``RetrievalQA`` chain over this tool's vector store."""
        # Imported inside the method, as in the original; presumably this keeps
        # ``langchain`` a lazy dependency of this module (TODO confirm).
        from langchain.chains.retrieval_qa.base import RetrievalQA

        llm = self.llm
        return RetrievalQA.from_chain_type(
            llm, retriever=self.vectorstore.as_retriever()
        )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Answer ``query`` synchronously and return the chain's answer text."""
        qa_chain = self._build_qa_chain()
        # Forward a child callback manager only when a run manager is present.
        callbacks = None
        if run_manager:
            callbacks = run_manager.get_child()
        outputs = qa_chain.invoke(
            {qa_chain.input_key: query},
            config={"callbacks": callbacks},
        )
        return outputs[qa_chain.output_key]

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronous counterpart of ``_run``."""
        qa_chain = self._build_qa_chain()
        callbacks = None
        if run_manager:
            callbacks = run_manager.get_child()
        outputs = await qa_chain.ainvoke(
            {qa_chain.input_key: query},
            config={"callbacks": callbacks},
        )
        return outputs[qa_chain.output_key]
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
    """Tool for the VectorDBQAWithSources chain.

    Each call builds a ``RetrievalQAWithSourcesChain`` from ``self.llm`` and
    ``self.vectorstore.as_retriever()``. It invokes the chain with
    ``return_only_outputs=True`` and returns those outputs as a JSON string.
    According to the description text, the outputs carry ``answer`` and
    ``sources`` keys.
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Render the description text for a sourced-QA tool.

        Args:
            name: Substituted for the ``{name}`` placeholder.
            description: Substituted for the ``{description}`` placeholder.

        Returns:
            The formatted description string.
        """
        # Every fragment ends with exactly one space and none starts with one,
        # so sentences in the rendered text are separated by a single space.
        template: str = (
            "Useful for when you need to answer questions about {name} and the sources "
            "used to construct the answer. "
            "Whenever you need information about {description} "
            "you should ALWAYS use this. "
            "Input should be a fully formed question. "
            "Output is a json serialized dictionary with keys `answer` and `sources`. "
            "Only use this tool if the user explicitly asks for sources."
        )
        return template.format(name=name, description=description)

    def _build_qa_chain(self) -> Any:
        """Create a new ``RetrievalQAWithSourcesChain`` over the vector store."""
        # Imported inside the method, as in the original; presumably this keeps
        # ``langchain`` a lazy dependency of this module (TODO confirm).
        from langchain.chains.qa_with_sources.retrieval import (
            RetrievalQAWithSourcesChain,
        )

        return RetrievalQAWithSourcesChain.from_chain_type(
            self.llm, retriever=self.vectorstore.as_retriever()
        )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Answer ``query`` synchronously; return the chain outputs as JSON."""
        chain = self._build_qa_chain()
        outputs = chain.invoke(
            {chain.question_key: query},
            return_only_outputs=True,
            # Forward a child callback manager only when a run manager exists.
            config={"callbacks": run_manager.get_child() if run_manager else None},
        )
        return json.dumps(outputs)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronous counterpart of ``_run``; returns the outputs as JSON."""
        chain = self._build_qa_chain()
        outputs = await chain.ainvoke(
            {chain.question_key: query},
            return_only_outputs=True,
            config={"callbacks": run_manager.get_child() if run_manager else None},
        )
        return json.dumps(outputs)