import concurrent
import concurrent.futures  # explicit: `import concurrent` alone does not load the futures submodule
import logging
import random
from pathlib import Path
from typing import Any, Callable, Iterator, List, Optional, Sequence, Type, Union

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.html_bs import BSHTMLLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
# Union of the concrete file-loader classes this module dispatches to;
# used as the type of DirectoryLoader's ``loader_cls`` argument.
FILE_LOADER_TYPE = Union[
    Type[UnstructuredFileLoader], Type[TextLoader], Type[BSHTMLLoader], Type[CSVLoader]
]
# Module-level logger keyed to this module's import path.
logger = logging.getLogger(__name__)
def _is_visible(p: Path) -> bool:
parts = p.parts
for _p in parts:
if _p.startswith("."):
return False
return True
class DirectoryLoader(BaseLoader):
    """Load documents from a directory with a configurable per-file loader."""

    def __init__(
        self,
        path: str,
        glob: str = "**/[!.]*",
        silent_errors: bool = False,
        load_hidden: bool = False,
        loader_cls: FILE_LOADER_TYPE = UnstructuredFileLoader,
        loader_kwargs: Union[dict, None] = None,
        recursive: bool = False,
        show_progress: bool = False,
        use_multithreading: bool = False,
        max_concurrency: int = 4,
        *,
        exclude: Union[Sequence[str], str] = (),
        sample_size: int = 0,
        randomize_sample: bool = False,
        sample_seed: Union[int, None] = None,
    ):
        """Initialize with a directory path and glob pattern.

        Args:
            path: Path to the directory.
            glob: Glob pattern to use to find files.
                Defaults to "**/[!.]*" (all files except hidden ones).
            exclude: A pattern or list of patterns to exclude from results.
                Uses glob syntax.
            silent_errors: Whether to silently ignore errors. Defaults to False.
            load_hidden: Whether to load hidden files. Defaults to False.
            loader_cls: Loader class to use for loading files.
                Defaults to UnstructuredFileLoader.
            loader_kwargs: Keyword arguments to pass to loader_cls.
                Defaults to None.
            recursive: Whether to recursively search for files.
                Defaults to False.
            show_progress: Whether to show a progress bar. Defaults to False.
            use_multithreading: Whether to use multithreading. Defaults to False.
            max_concurrency: The maximum number of threads to use.
                Defaults to 4.
            sample_size: The maximum number of files you would like to load
                from the directory.
            randomize_sample: Shuffle the files to get a random sample.
            sample_seed: Set the seed of the random shuffle for reproducibility.

        Examples:

            .. code-block:: python

                from langchain_community.document_loaders import DirectoryLoader

                # Load all non-hidden files in a directory.
                loader = DirectoryLoader("/path/to/directory")

                # Load all text files in a directory without recursion.
                loader = DirectoryLoader("/path/to/directory", glob="*.txt")

                # Recursively load all text files in a directory.
                loader = DirectoryLoader(
                    "/path/to/directory", glob="*.txt", recursive=True
                )

                # Load all files in a directory, except for py files.
                loader = DirectoryLoader("/path/to/directory", exclude="*.py")

                # Load all files in a directory, except for py or pyc files.
                loader = DirectoryLoader(
                    "/path/to/directory", exclude=["*.py", "*.pyc"]
                )
        """
        if loader_kwargs is None:
            loader_kwargs = {}
        if isinstance(exclude, str):
            # Normalize a single pattern to a one-element tuple.
            exclude = (exclude,)
        self.path = path
        self.glob = glob
        self.exclude = exclude
        self.load_hidden = load_hidden
        self.loader_cls = loader_cls
        self.loader_kwargs = loader_kwargs
        self.silent_errors = silent_errors
        self.recursive = recursive
        self.show_progress = show_progress
        self.use_multithreading = use_multithreading
        self.max_concurrency = max_concurrency
        self.sample_size = sample_size
        self.randomize_sample = randomize_sample
        self.sample_seed = sample_seed

    def load(self) -> List[Document]:
        """Eagerly load all documents.

        Returns:
            The full list of documents produced by :meth:`lazy_load`.
        """
        return list(self.lazy_load())

    def lazy_load(self) -> Iterator[Document]:
        """Lazily load documents from the directory.

        Yields:
            Documents produced by ``loader_cls`` for each matching file.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
            ValueError: If ``self.path`` is not a directory.
            ImportError: If ``show_progress`` is True, tqdm is not installed,
                and ``silent_errors`` is False.
        """
        p = Path(self.path)
        if not p.exists():
            raise FileNotFoundError(f"Directory not found: '{self.path}'")
        if not p.is_dir():
            raise ValueError(f"Expected directory, got file: '{self.path}'")

        paths = p.rglob(self.glob) if self.recursive else p.glob(self.glob)
        items = [
            path
            for path in paths
            if not (self.exclude and any(path.match(glob) for glob in self.exclude))
            and path.is_file()
        ]

        if self.sample_size > 0:
            if self.randomize_sample:
                # BUGFIX: pass the seed through unconditionally so that
                # ``sample_seed=0`` still yields a reproducible shuffle.
                # The previous ``self.sample_seed if self.sample_seed else
                # None`` silently discarded a falsy seed of 0.
                randomizer = random.Random(self.sample_seed)
                randomizer.shuffle(items)
            items = items[: min(len(items), self.sample_size)]

        pbar = None
        if self.show_progress:
            try:
                from tqdm import tqdm

                pbar = tqdm(total=len(items))
            except ImportError as e:
                logger.warning(
                    "To log the progress of DirectoryLoader you need to install tqdm, "
                    "`pip install tqdm`"
                )
                if self.silent_errors:
                    logger.warning(e)
                else:
                    raise ImportError(
                        "To log the progress of DirectoryLoader "
                        "you need to install tqdm, "
                        "`pip install tqdm`"
                    )

        if self.use_multithreading:
            futures = []
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.max_concurrency
            ) as executor:
                for i in items:
                    futures.append(
                        executor.submit(
                            # Workers must return a materialized list: a raw
                            # generator would only execute when iterated, i.e.
                            # on the consumer thread, defeating parallelism.
                            self._lazy_load_file_to_non_generator(
                                self._lazy_load_file
                            ),
                            i,
                            p,
                            pbar,
                        )
                    )
                for future in concurrent.futures.as_completed(futures):
                    for item in future.result():
                        yield item
        else:
            for i in items:
                yield from self._lazy_load_file(i, p, pbar)

        if pbar:
            pbar.close()

    def _lazy_load_file_to_non_generator(self, func: Callable) -> Callable:
        """Wrap a generator function so it returns a list instead."""

        def non_generator(item: Path, path: Path, pbar: Optional[Any]) -> List:
            return [x for x in func(item, path, pbar)]

        return non_generator

    def _lazy_load_file(
        self, item: Path, path: Path, pbar: Optional[Any]
    ) -> Iterator[Document]:
        """Load a single file.

        Args:
            item: File path.
            path: Directory path, used to compute the relative path for the
                hidden-file check.
            pbar: Progress bar, advanced once per file. Defaults to None.
        """
        if item.is_file():
            if _is_visible(item.relative_to(path)) or self.load_hidden:
                try:
                    logger.debug(f"Processing file: {str(item)}")
                    loader = self.loader_cls(str(item), **self.loader_kwargs)
                    try:
                        for subdoc in loader.lazy_load():
                            yield subdoc
                    except NotImplementedError:
                        # Older loaders may only implement eager load().
                        for subdoc in loader.load():
                            yield subdoc
                except Exception as e:
                    if self.silent_errors:
                        logger.warning(f"Error loading file {str(item)}: {e}")
                    else:
                        logger.error(f"Error loading file {str(item)}")
                        raise e
                finally:
                    if pbar:
                        pbar.update(1)