from __future__ import annotations
import asyncio
import inspect
import logging
import re
from typing import (
Callable,
Iterator,
List,
Optional,
Sequence,
Set,
Union,
cast,
)
import aiohttp
import requests
from langchain_core.documents import Document
from langchain_core.utils.html import extract_sub_links
from langchain_community.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
def _metadata_extractor(
raw_html: str, url: str, response: Union[requests.Response, aiohttp.ClientResponse]
) -> dict:
"""使用BeautifulSoup从原始HTML中提取元数据。"""
content_type = getattr(response, "headers").get("Content-Type", "")
metadata = {"source": url, "content_type": content_type}
try:
from bs4 import BeautifulSoup
except ImportError:
logger.warning(
"The bs4 package is required for default metadata extraction. "
"Please install it with `pip install bs4`."
)
return metadata
soup = BeautifulSoup(raw_html, "html.parser")
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", None)
if html := soup.find("html"):
metadata["language"] = html.get("lang", None)
return metadata
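
# Illustrative sketch (comment only, not part of this module's API): a custom metadata
# extractor can be passed to ``RecursiveUrlLoader`` instead of the default
# ``_metadata_extractor`` above. Both a two-argument ``(raw_html, url)`` and a
# three-argument ``(raw_html, url, response)`` signature are accepted;
# ``_wrap_metadata_extractor`` at the bottom of this module adapts the former.
#
#     def simple_metadata_extractor(raw_html: str, url: str) -> dict:
#         # Record only the source URL; skip any HTML parsing.
#         return {"source": url}
#
#     loader = RecursiveUrlLoader(
#         "https://docs.python.org/3.9/",  # sample URL, not prescribed by this module
#         metadata_extractor=simple_metadata_extractor,
#     )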
class RecursiveUrlLoader(BaseLoader):
"""从URL页面加载所有子链接。
**安全提示** :此加载器是一个爬虫,将从给定的URL开始爬取,然后扩展到递归爬取子链接。
Web爬虫通常不应该部署具有对任何内部服务器的网络访问权限。
控制谁可以提交爬取请求以及爬虫具有什么网络访问权限。
在爬取过程中,爬虫可能会遇到恶意URL,这可能导致服务器端请求伪造(SSRF)攻击。
为了降低风险,默认情况下,爬虫只会加载与起始URL相同域的URL(通过prevent_outside命名参数控制)。
这将降低SSRF攻击的风险,但不会完全消除。
例如,如果爬取托管了几个站点的主机:
https://some_host/alice_site/
https://some_host/bob_site/
Alice站点上的恶意URL可能导致爬虫向Bob站点的端点发出恶意GET请求。由于这两个站点托管在同一主机上,因此默认情况下不会阻止这样的请求。
请参阅 https://python.langchain.com/docs/security"""
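
    # Usage sketch (comment only; the URLs are the hypothetical ones from the note
    # above): with the default ``prevent_outside=True``, links outside ``base_url``
    # (which defaults to the start URL) are dropped, so the boundary can be pinned
    # explicitly:
    #
    #     loader = RecursiveUrlLoader(
    #         "https://some_host/alice_site/",
    #         base_url="https://some_host/alice_site/",
    #         prevent_outside=True,
    #     )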
    def __init__(
self,
url: str,
max_depth: Optional[int] = 2,
use_async: Optional[bool] = None,
extractor: Optional[Callable[[str], str]] = None,
metadata_extractor: Optional[_MetadataExtractorType] = None,
exclude_dirs: Optional[Sequence[str]] = (),
timeout: Optional[int] = 10,
prevent_outside: bool = True,
link_regex: Union[str, re.Pattern, None] = None,
headers: Optional[dict] = None,
check_response_status: bool = False,
continue_on_failure: bool = True,
*,
base_url: Optional[str] = None,
autoset_encoding: bool = True,
encoding: Optional[str] = None,
) -> None:
"""初始化爬取的URL和要排除的任何子目录。
参数:
url: 要爬取的URL。
max_depth: 递归加载的最大深度。
use_async: 是否使用异步加载。
如果为True,则此函数不会是懒加载,但仍会按预期方式工作,只是不是懒加载。
extractor: 从原始HTML中提取文档内容的函数。
当提取函数返回空字符串时,将忽略该文档。
metadata_extractor: 从原始HTML、源URL和requests.Response/aiohttp.ClientResponse对象提取元数据的函数
(按照这个顺序的参数)。
默认提取器将尝试使用BeautifulSoup4来提取页面的标题、描述和语言。
..code-block:: python
import requests
import aiohttp
def simple_metadata_extractor(
raw_html: str, url: str, response: Union[requests.Response, aiohttp.ClientResponse]
) -> dict:
content_type = getattr(response, "headers").get("Content-Type", "")
return {"source": url, "content_type": content_type}
exclude_dirs: 要排除的子目录列表。
timeout: 请求超时时间,单位为秒。如果为None,则连接不会超时。
prevent_outside: 如果为True,则阻止从不是根URL的子URL加载。
link_regex: 从网页的原始HTML中提取子链接的正则表达式。
check_response_status: 如果为True,则检查HTTP响应状态并跳过具有错误响应(400-599)的URL。
continue_on_failure: 如果为True,则在获取或解析链接时出现异常时继续。否则,引发异常。
base_url: 用于检查外部链接的基本URL。
autoset_encoding: 是否自动设置响应的编码。
如果为True,则响应的编码将设置为明显的编码,除非已经显式设置了`encoding`参数。
encoding: 响应的编码。如果手动设置,编码将设置为给定值,而不管`autoset_encoding`参数如何。
""" # noqa: E501
self.url = url
self.max_depth = max_depth if max_depth is not None else 2
self.use_async = use_async if use_async is not None else False
self.extractor = extractor if extractor is not None else lambda x: x
metadata_extractor = (
metadata_extractor
if metadata_extractor is not None
else _metadata_extractor
)
self.autoset_encoding = autoset_encoding
self.encoding = encoding
self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor)
self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()
if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
raise ValueError(
f"Base url is included in exclude_dirs. Received base_url: {url} and "
f"exclude_dirs: {self.exclude_dirs}"
)
self.timeout = timeout
self.prevent_outside = prevent_outside if prevent_outside is not None else True
self.link_regex = link_regex
self._lock = asyncio.Lock() if self.use_async else None
self.headers = headers
self.check_response_status = check_response_status
self.continue_on_failure = continue_on_failure
self.base_url = base_url if base_url is not None else url
def _get_child_links_recursive(
self, url: str, visited: Set[str], *, depth: int = 0
) -> Iterator[Document]:
"""递归获取以输入URL路径开头的所有子链接。
参数:
url:要爬取的URL。
visited:已访问URL的集合。
depth:递归的当前深度。当深度 >= 最大深度时停止。
"""
if depth >= self.max_depth:
return
# Get all links that can be accessed from the current URL
visited.add(url)
try:
response = requests.get(url, timeout=self.timeout, headers=self.headers)
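            # Prefer an explicitly configured encoding; otherwise optionally fall
            # back to the encoding detected by requests (``apparent_encoding``).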
if self.encoding is not None:
response.encoding = self.encoding
elif self.autoset_encoding:
response.encoding = response.apparent_encoding
if self.check_response_status and 400 <= response.status_code <= 599:
raise ValueError(f"Received HTTP status {response.status_code}")
except Exception as e:
if self.continue_on_failure:
logger.warning(
f"Unable to load from {url}. Received error {e} of type "
f"{e.__class__.__name__}"
)
return
else:
raise e
content = self.extractor(response.text)
if content:
yield Document(
page_content=content,
metadata=self.metadata_extractor(response.text, url, response),
)
# Store the visited links and recursively visit the children
sub_links = extract_sub_links(
response.text,
url,
base_url=self.base_url,
pattern=self.link_regex,
prevent_outside=self.prevent_outside,
exclude_prefixes=self.exclude_dirs,
continue_on_failure=self.continue_on_failure,
)
for link in sub_links:
# Check all unvisited links
if link not in visited:
yield from self._get_child_links_recursive(
link, visited, depth=depth + 1
)
async def _async_get_child_links_recursive(
self,
url: str,
visited: Set[str],
*,
session: Optional[aiohttp.ClientSession] = None,
depth: int = 0,
) -> List[Document]:
"""递归获取以输入URL路径开头的所有子链接。
参数:
url:要爬取的URL。
visited:已访问过的URL集合。
depth:到达当前URL时,已访问了多少页面。
"""
if not self.use_async or not self._lock:
raise ValueError(
"Async functions forbidden when not initialized with `use_async`"
)
try:
import aiohttp
except ImportError:
raise ImportError(
"The aiohttp package is required for the RecursiveUrlLoader. "
"Please install it with `pip install aiohttp`."
)
if depth >= self.max_depth:
return []
        # Disable SSL verification because websites may have invalid SSL certificates,
        # which won't cause any security issues for us.
close_session = session is None
session = (
session
if session is not None
else aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl=False),
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers=self.headers,
)
)
async with self._lock:
visited.add(url)
try:
async with session.get(url) as response:
text = await response.text()
if self.check_response_status and 400 <= response.status <= 599:
raise ValueError(f"Received HTTP status {response.status}")
except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
if close_session:
await session.close()
if self.continue_on_failure:
logger.warning(
f"Unable to load {url}. Received error {e} of type "
f"{e.__class__.__name__}"
)
return []
else:
raise e
results = []
content = self.extractor(text)
if content:
results.append(
Document(
page_content=content,
metadata=self.metadata_extractor(text, url, response),
)
)
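        # Skip link extraction when the children (visited at depth + 1) would exceed
        # max_depth anyway.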
if depth < self.max_depth - 1:
sub_links = extract_sub_links(
text,
url,
base_url=self.base_url,
pattern=self.link_regex,
prevent_outside=self.prevent_outside,
exclude_prefixes=self.exclude_dirs,
continue_on_failure=self.continue_on_failure,
)
# Recursively call the function to get the children of the children
sub_tasks = []
async with self._lock:
to_visit = set(sub_links).difference(visited)
for link in to_visit:
sub_tasks.append(
self._async_get_child_links_recursive(
link, visited, session=session, depth=depth + 1
)
)
next_results = await asyncio.gather(*sub_tasks)
for sub_result in next_results:
if isinstance(sub_result, Exception) or sub_result is None:
# We don't want to stop the whole process, so just ignore it
# Not standard html format or invalid url or 404 may cause this.
continue
# locking not fully working, temporary hack to ensure deduplication
results += [r for r in sub_result if r not in results]
if close_session:
await session.close()
return results
    def lazy_load(self) -> Iterator[Document]:
"""延迟加载网页。
当use_async为True时,此函数将不再是延迟加载的,
但它仍将按预期方式工作,只是不再是延迟加载。
"""
visited: Set[str] = set()
if self.use_async:
results = asyncio.run(
self._async_get_child_links_recursive(self.url, visited)
)
return iter(results or [])
else:
return self._get_child_links_recursive(self.url, visited)
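
# Usage sketch (comment only; the URL is illustrative): ``lazy_load`` yields Documents
# one page at a time, while ``load`` (inherited from ``BaseLoader``) collects them into
# a list. With ``use_async=True`` the whole crawl runs eagerly under ``asyncio.run``.
#
#     loader = RecursiveUrlLoader("https://docs.python.org/3.9/", max_depth=2)
#     for doc in loader.lazy_load():
#         print(doc.metadata["source"])
#
#     docs = RecursiveUrlLoader("https://docs.python.org/3.9/", use_async=True).load()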
_MetadataExtractorType1 = Callable[[str, str], dict]
_MetadataExtractorType2 = Callable[
[str, str, Union[requests.Response, aiohttp.ClientResponse]], dict
]
_MetadataExtractorType = Union[_MetadataExtractorType1, _MetadataExtractorType2]
def _wrap_metadata_extractor(
metadata_extractor: _MetadataExtractorType,
) -> _MetadataExtractorType2:
if len(inspect.signature(metadata_extractor).parameters) == 3:
return cast(_MetadataExtractorType2, metadata_extractor)
else:
def _metadata_extractor_wrapper(
raw_html: str,
url: str,
response: Union[requests.Response, aiohttp.ClientResponse],
) -> dict:
return cast(_MetadataExtractorType1, metadata_extractor)(raw_html, url)
return _metadata_extractor_wrapper
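
# Behavior sketch for ``_wrap_metadata_extractor`` (comment only): a two-argument
# extractor is wrapped so the internal three-argument call still works; a
# three-argument extractor is returned unchanged.
#
#     def two_arg_extractor(raw_html: str, url: str) -> dict:
#         return {"source": url}
#
#     wrapped = _wrap_metadata_extractor(two_arg_extractor)
#     # wrapped(raw_html, url, response) == {"source": url}; the response is ignored.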