"""链式调用API并总结响应以回答问题。"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
"""从给定的URL中提取方案(scheme)+域名(domain)。
参数:
url(str):输入的URL。
返回:
返回一个包含方案和域名的2元组。
"""
parsed_uri = urlparse(url)
return parsed_uri.scheme, parsed_uri.netloc
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Check whether ``url`` targets one of the allowed origins.

    A URL is allowed only when its scheme AND its network location (the
    ``netloc`` reported by ``urlparse``) both exactly equal those of at least
    one entry in ``limit_to_domains``. The netloc comparison is a plain
    string comparison (case-sensitive).

    Args:
        url: The URL to check.
        limit_to_domains: Allowed origins, e.g. ``["https://www.example.com"]``.

    Returns:
        True if the URL matches an allowed entry, otherwise False. An empty
        ``limit_to_domains`` therefore allows nothing.
    """
    # Parse the candidate once; each allowed entry is parsed into its own
    # fresh name instead of rebinding the loop variable.
    target = _extract_scheme_and_domain(url)
    return any(
        _extract_scheme_and_domain(allowed) == target
        for allowed in limit_to_domains
    )
try:
    from langchain_community.utilities.requests import TextRequestsWrapper

    class APIChain(Chain):
        """Chain that makes an API call and summarizes the response to answer a question.

        An LLM first turns the question plus ``api_docs`` into a URL. The chain
        then fetches that URL with ``requests_wrapper.get`` (or ``aget``). A
        second LLM call turns the raw response into an answer.

        *Security Note*: This chain issues HTTP requests to URLs produced by an
            LLM from user input. Be careful who is allowed to use this chain. If
            it is exposed to end users, they can make arbitrary requests on
            behalf of the server hosting the code. For example, a user could ask
            the server to call a private API that is only reachable from the
            server. Control who can submit questions to this chain and what
            network access it has, and keep ``limit_to_domains`` as narrow as
            possible.

            See https://python.langchain.com/docs/security for more information.
        """

        api_request_chain: LLMChain
        api_answer_chain: LLMChain
        requests_wrapper: TextRequestsWrapper = Field(exclude=True)
        api_docs: str
        question_key: str = "question"  #: :meta private:
        output_key: str = "output"  #: :meta private:
        limit_to_domains: Optional[Sequence[str]]
        """Restricts which domains the API chain may access.

        * For example, to allow only `https://www.example.com`, set
          `limit_to_domains=["https://www.example.com"]`.
        * The value must be passed explicitly. `from_llm_and_api_docs`
          defaults it to an empty tuple, and an empty sequence is rejected
          at instantiation. This is deliberate, so no domain is allowed
          silently.
        * Pass None to allow all domains. This is NOT recommended for security
          reasons: it lets malicious users make requests to any URL the server
          can reach, including internal APIs.
        """

        @property
        def input_keys(self) -> List[str]:
            """Expect the single question input key.

            :meta private:
            """
            return [self.question_key]

        @property
        def output_keys(self) -> List[str]:
            """Expect the single output key.

            :meta private:
            """
            return [self.output_key]

        @root_validator(pre=True)
        def validate_api_request_prompt(cls, values: Dict) -> Dict:
            """Check that the API request prompt expects exactly the right variables."""
            input_vars = values["api_request_chain"].prompt.input_variables
            expected_vars = {"question", "api_docs"}
            if set(input_vars) != expected_vars:
                raise ValueError(
                    f"Input variables should be {expected_vars}, got {input_vars}"
                )
            return values

        @root_validator(pre=True)
        def validate_limit_to_domains(cls, values: Dict) -> Dict:
            """Require `limit_to_domains` to be given and, unless None, non-empty."""
            if "limit_to_domains" not in values:
                raise ValueError(
                    "You must specify a list of domains to limit access using "
                    "`limit_to_domains`"
                )
            # None is an explicit opt-in to "allow all"; an empty sequence is
            # treated as a configuration mistake.
            if (
                not values["limit_to_domains"]
                and values["limit_to_domains"] is not None
            ):
                raise ValueError(
                    "Please provide a list of domains to limit access using "
                    "`limit_to_domains`."
                )
            return values

        @root_validator(pre=True)
        def validate_api_answer_prompt(cls, values: Dict) -> Dict:
            """Check that the API answer prompt expects exactly the right variables."""
            input_vars = values["api_answer_chain"].prompt.input_variables
            expected_vars = {"question", "api_docs", "api_url", "api_response"}
            if set(input_vars) != expected_vars:
                raise ValueError(
                    f"Input variables should be {expected_vars}, got {input_vars}"
                )
            return values

        def _check_api_url_allowed(self, api_url: str) -> None:
            """Raise ValueError if `api_url` falls outside `limit_to_domains`.

            The check is skipped when `limit_to_domains` is None, which is the
            explicit "allow all" setting. Shared by the sync and async paths so
            the security check cannot drift between them.
            """
            if self.limit_to_domains and not _check_in_allowed_domain(
                api_url, self.limit_to_domains
            ):
                raise ValueError(
                    f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
                )

        def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            """Generate the URL, fetch it, and summarize the response."""
            _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
            question = inputs[self.question_key]
            api_url = self.api_request_chain.predict(
                question=question,
                api_docs=self.api_docs,
                callbacks=_run_manager.get_child(),
            )
            _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
            # LLM output often carries surrounding whitespace/newlines.
            api_url = api_url.strip()
            self._check_api_url_allowed(api_url)
            api_response = self.requests_wrapper.get(api_url)
            _run_manager.on_text(
                str(api_response), color="yellow", end="\n", verbose=self.verbose
            )
            answer = self.api_answer_chain.predict(
                question=question,
                api_docs=self.api_docs,
                api_url=api_url,
                api_response=api_response,
                callbacks=_run_manager.get_child(),
            )
            return {self.output_key: answer}

        async def _acall(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            """Async counterpart of `_call`."""
            _run_manager = (
                run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
            )
            question = inputs[self.question_key]
            api_url = await self.api_request_chain.apredict(
                question=question,
                api_docs=self.api_docs,
                callbacks=_run_manager.get_child(),
            )
            await _run_manager.on_text(
                api_url, color="green", end="\n", verbose=self.verbose
            )
            # LLM output often carries surrounding whitespace/newlines.
            api_url = api_url.strip()
            self._check_api_url_allowed(api_url)
            api_response = await self.requests_wrapper.aget(api_url)
            await _run_manager.on_text(
                str(api_response), color="yellow", end="\n", verbose=self.verbose
            )
            answer = await self.api_answer_chain.apredict(
                question=question,
                api_docs=self.api_docs,
                api_url=api_url,
                api_response=api_response,
                callbacks=_run_manager.get_child(),
            )
            return {self.output_key: answer}

        @classmethod
        def from_llm_and_api_docs(
            cls,
            llm: BaseLanguageModel,
            api_docs: str,
            headers: Optional[dict] = None,
            api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
            api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
            limit_to_domains: Optional[Sequence[str]] = (),
            **kwargs: Any,
        ) -> APIChain:
            """Build the chain from just an LLM and the API docs.

            Args:
                llm: Model used for both URL generation and answer summarization.
                api_docs: API documentation text passed to both prompts.
                headers: Optional HTTP headers for the requests wrapper.
                api_url_prompt: Prompt that turns the question into a URL.
                api_response_prompt: Prompt that summarizes the API response.
                limit_to_domains: Allowed origins. The empty-tuple default is
                    rejected by validation, so callers must pass domains or None.
                **kwargs: Forwarded to the constructor.
            """
            get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
            requests_wrapper = TextRequestsWrapper(headers=headers)
            get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
            return cls(
                api_request_chain=get_request_chain,
                api_answer_chain=get_answer_chain,
                requests_wrapper=requests_wrapper,
                api_docs=api_docs,
                limit_to_domains=limit_to_domains,
                **kwargs,
            )

        @property
        def _chain_type(self) -> str:
            return "api_chain"

except ImportError:

    class APIChain:  # type: ignore[no-redef]
        """Placeholder that fails loudly when langchain_community is not installed."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise ImportError(
                "To use the APIChain, you must install the langchain_community package. "
                "pip install langchain_community"
            )