Source code for langchain.chains.openai_functions.qa_with_structure
from typing import Any, List, Optional, Type, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.output_parsers.openai_functions import (
OutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
class AnswerWithSources(BaseModel):
    """An answer to the question, with sources."""

    # Both fields pass ``...`` as the default, i.e. no default value is
    # supplied and callers/parsers must provide them.
    answer: str = Field(
        ...,
        description="Answer to the question that was asked",
    )
    sources: List[str] = Field(
        ...,
        description="List of sources used to answer the question",
    )
def create_qa_with_structure_chain(
    llm: BaseLanguageModel,
    schema: Union[dict, Type[BaseModel]],
    output_parser: str = "base",
    prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
    verbose: bool = False,
) -> LLMChain:
    """Create a question answering chain whose answer follows a given schema.

    Args:
        llm: Language model to use for the chain.
        schema: Output schema, either a Pydantic model class or a dict.
            A dict must contain a ``title`` key (used as the function name);
            a ``description`` key is optional.
        output_parser: Output parser to use. Should be one of ``pydantic`` or
            ``base``. Defaults to ``base``.
        prompt: Optional prompt to use for the chain. When omitted (or falsy),
            a default chat prompt with ``{context}`` and ``{question}``
            variables is built.
        verbose: Forwarded to the ``LLMChain`` constructor.

    Returns:
        An ``LLMChain`` configured with the function definition derived from
        ``schema`` and the selected output parser.

    Raises:
        ValueError: If ``output_parser`` is ``pydantic`` but ``schema`` is not
            a Pydantic model class, or if ``output_parser`` is neither
            ``pydantic`` nor ``base``.
        KeyError: If the schema has no ``title`` entry.
    """
    if output_parser == "pydantic":
        if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
            raise ValueError(
                "Must provide a pydantic class for schema when output_parser is "
                "'pydantic'."
            )
        _output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser(
            pydantic_schema=schema
        )
    elif output_parser == "base":
        _output_parser = OutputFunctionsParser()
    else:
        raise ValueError(
            f"Got unexpected output_parser: {output_parser}. "
            "Should be one of `pydantic` or `base`."
        )

    # Normalise the schema to a plain dict regardless of how it was supplied.
    if isinstance(schema, type) and issubclass(schema, BaseModel):
        schema_dict = schema.schema()
    else:
        schema_dict = schema

    function = {
        "name": schema_dict["title"],
        # The description is optional metadata: a model without a docstring
        # (or a hand-written dict) may not carry one, so fall back to an empty
        # string instead of failing with a KeyError.
        "description": schema_dict.get("description", ""),
        "parameters": schema_dict,
    }
    llm_kwargs = get_llm_kwargs(function)

    # Only build the default prompt when the caller did not supply a usable one.
    if not prompt:
        messages = [
            SystemMessage(
                content=(
                    "You are a world class algorithm to answer "
                    "questions in a specific format."
                )
            ),
            HumanMessage(content="Answer question using the following context"),
            HumanMessagePromptTemplate.from_template("{context}"),
            HumanMessagePromptTemplate.from_template("Question: {question}"),
            HumanMessage(content="Tips: Make sure to answer in the correct format"),
        ]
        prompt = ChatPromptTemplate(messages=messages)  # type: ignore[arg-type, call-arg]

    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs=llm_kwargs,
        output_parser=_output_parser,
        verbose=verbose,
    )
    return chain
def create_qa_with_sources_chain(
    llm: BaseLanguageModel, verbose: bool = False, **kwargs: Any
) -> LLMChain:
    """Create a question answering chain that returns an answer with sources.

    Args:
        llm: Language model to use for the chain.
        verbose: Whether to print the details of the chain.
        **kwargs: Keyword arguments to pass to
            `create_qa_with_structure_chain`.

    Returns:
        Chain (LLMChain) that can be used to answer questions with citations.
    """
    # The output structure is fixed to AnswerWithSources; every other option
    # is forwarded unchanged to the generic builder.
    qa_chain = create_qa_with_structure_chain(
        llm,
        AnswerWithSources,
        verbose=verbose,
        **kwargs,
    )
    return qa_chain