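"""Util that calls the you.com Search API.

To set it up, get an API key at https://api.you.com and expose it as the
`YDC_API_KEY` environment variable (or pass `ydc_api_key` to the wrapper).
API documentation is available at https://documentation.you.com.
"""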
from typing import Any, Dict, List, Literal, Optional
import aiohttp
import requests
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
# Base URL of the you.com API. Requests append the endpoint name
# (e.g. "search" or "news") to it as a path segment.
YOU_API_URL = "https://api.ydc-index.io"
# NOTE(review): `YouHitMetadata` was used as a base class and as a field type
# in this module but never defined, which is a NameError at import time. Its
# fields are reconstructed from the metadata keys that
# `YouSearchAPIWrapper._parse_results` emits (url, thumbnail_url, title,
# description). Confirm the field types against the you.com API reference.
class YouHitMetadata(BaseModel):
    """Metadata describing a single hit from you.com."""

    title: str = Field(description="The title of the result")
    url: str = Field(description="The url of the result")
    thumbnail_url: str = Field(description="Thumbnail associated with the result")
    description: str = Field(description="Details about the result")


class YouHit(YouHitMetadata):
    """A single hit from you.com, which may contain multiple snippets."""

    snippets: List[str] = Field(description="One or more snippets of text")
class YouAPIOutput(BaseModel):
    """Response model of the you.com search API: a list of hits."""

    # Each entry is one search hit together with its snippets.
    hits: List[YouHit] = Field(
        description="A list of dictionaries containing the results",
    )
class YouDocument(BaseModel):
    """Parsed output for one snippet: its text plus the hit's metadata."""

    # Text of the snippet itself.
    page_content: str = Field(description="One snippet of text")
    # Metadata of the hit the snippet belongs to.
    metadata: YouHitMetadata
class YouSearchAPIWrapper(BaseModel):
    """Wrapper for the you.com Search API.

    Connecting to the You.com API requires an API key, which you can get at
    https://api.you.com. Documentation is available at
    https://documentation.you.com.

    The environment variable `YDC_API_KEY` must be set (or `ydc_api_key`
    passed explicitly) for the retriever to run.

    Attributes
    ----------
    ydc_api_key: str, optional
        you.com API key, if `YDC_API_KEY` is not set in the environment
    num_web_results: int, optional
        The max number of web results to return, must be under 20
    safesearch: str, optional
        Safesearch setting, one of off, moderate, strict; defaults to moderate
    country: str, optional
        Country code, e.g. 'US' for the United States; see the API docs
        for the full list
    k: int, optional
        Max number of Documents returned by `results()` / `results_async()`
    n_hits: int, optional, deprecated
        Alias for num_web_results; used only when `num_web_results` is unset
    n_snippets_per_hit: int, optional
        Limit on the number of snippets turned into Documents per hit
    endpoint_type: str, optional
        you.com endpoint: search, news, rag.
        `snippet` is a deprecated alias for `search`.
        `rag` has been observed to return `{'message': 'Forbidden'}`.
    """

    ydc_api_key: Optional[str] = None
    num_web_results: Optional[int] = None
    safesearch: Optional[str] = None
    country: Optional[str] = None
    k: Optional[int] = None
    n_snippets_per_hit: Optional[int] = None
    # @todo deprecate `snippet`, not part of API
    endpoint_type: Literal["search", "news", "rag", "snippet"] = "search"
    # should deprecate n_hits
    n_hits: Optional[int] = None

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from `ydc_api_key` or the `YDC_API_KEY` env var."""
        ydc_api_key = get_from_dict_or_env(values, "ydc_api_key", "YDC_API_KEY")
        values["ydc_api_key"] = ydc_api_key
        return values

    def _resolved_endpoint(self) -> str:
        """Return the endpoint path to call, mapping `snippet` to `search`.

        @todo deprecate `snippet`, not part of API. The alias is resolved
        locally instead of by reassigning `self.endpoint_type`, so making a
        request does not silently change the wrapper's configuration.
        """
        return "search" if self.endpoint_type == "snippet" else self.endpoint_type

    def _headers(self) -> Dict[str, str]:
        """Build the authentication headers sent with every request."""
        return {"X-API-Key": self.ydc_api_key or ""}

    def _generate_params(self, query: str, **kwargs: Any) -> Dict:
        """Build the query-string parameters shared by the sync and async calls.

        Args:
            query: The query to search for.
            **kwargs: Extra parameters. They override the instance-level
                values of the same name.

        Returns:
            The parameters with every ``None`` value removed.
        """
        # `n_hits` is the documented, deprecated alias of `num_web_results`.
        num_web_results = (
            self.num_web_results if self.num_web_results is not None else self.n_hits
        )
        params = {
            "query": query,
            "num_web_results": num_web_results,
            "safesearch": self.safesearch,
            "country": self.country,
            **kwargs,
        }
        params = {key: value for key, value in params.items() if value is not None}
        # news endpoint expects `q` instead of `query`
        if self.endpoint_type == "news":
            params["q"] = params.pop("query")
        return params

    def _parse_results(self, raw_search_results: Dict) -> List[Document]:
        """Convert a decoded API response into Documents.

        For the ``news`` endpoint, each entry of
        ``raw_search_results["news"]["results"]`` becomes a Document. Its
        content is the entry's ``description`` and its metadata is the entry
        itself.

        For other endpoints, each snippet of each hit in
        ``raw_search_results["hits"]`` becomes a Document. Its metadata holds
        the hit's url, thumbnail_url, title and description. At most
        ``n_snippets_per_hit`` snippets are taken per hit, and a hit without
        snippets contributes nothing.

        When ``k`` is set, at most ``k`` Documents are returned.

        Args:
            raw_search_results: Decoded JSON response from the API.

        Returns:
            List[Document]: The parsed Documents.
        """
        # return news results
        if self.endpoint_type == "news":
            news_results = raw_search_results["news"]["results"]
            if self.k is not None:
                news_results = news_results[: max(self.k, 0)]
            return [
                Document(page_content=result["description"], metadata=result)
                for result in news_results
            ]

        docs: List[Document] = []
        for hit in raw_search_results["hits"]:
            # Tolerate hits that carry no (or a null) `snippets` field.
            snippets = hit.get("snippets") or []
            # Compare against None so that 0 means "no snippets", not "all".
            if self.n_snippets_per_hit is not None:
                snippets = snippets[: self.n_snippets_per_hit]
            for snippet in snippets:
                # Check before appending so the result never exceeds `k`,
                # even when a single hit has several snippets.
                if self.k is not None and len(docs) >= self.k:
                    return docs
                docs.append(
                    Document(
                        page_content=snippet,
                        metadata={
                            "url": hit.get("url"),
                            "thumbnail_url": hit.get("thumbnail_url"),
                            "title": hit.get("title"),
                            "description": hit.get("description"),
                        },
                    )
                )
        return docs

    def raw_results(
        self,
        query: str,
        **kwargs: Any,
    ) -> Dict:
        """Run a query through you.com search and return the decoded JSON body.

        Args:
            query: The query to search for.
            **kwargs: Extra query-string parameters. They override the
                instance-level ``num_web_results`` / ``safesearch`` /
                ``country`` values. ``None`` values are dropped.

        Returns:
            The decoded JSON response.

        Raises:
            requests.HTTPError: If the response status signals an error.
        """
        response = requests.get(
            f"{YOU_API_URL}/{self._resolved_endpoint()}",
            params=self._generate_params(query, **kwargs),
            headers=self._headers(),
        )
        response.raise_for_status()
        return response.json()

    def results(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Run a query through you.com search and parse the results into Documents."""
        raw_search_results = self.raw_results(
            query,
            **{key: value for key, value in kwargs.items() if value is not None},
        )
        return self._parse_results(raw_search_results)

    async def raw_results_async(
        self,
        query: str,
        **kwargs: Any,
    ) -> Dict:
        """Asynchronously query the you.com search API and return the JSON body.

        Takes the same arguments as `raw_results`.

        Raises:
            Exception: If the response status is anything other than 200.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url=f"{YOU_API_URL}/{self._resolved_endpoint()}",
                params=self._generate_params(query, **kwargs),
                headers=self._headers(),
            ) as res:
                if res.status == 200:
                    return await res.json()
                raise Exception(f"Error {res.status}: {res.reason}")

    async def results_async(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously run a query and parse the results into Documents."""
        raw_search_results_async = await self.raw_results_async(
            query,
            **{key: value for key, value in kwargs.items() if value is not None},
        )
        return self._parse_results(raw_search_results_async)