Source code for langchain_community.utilities.tensorflow_datasets
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
logger = logging.getLogger(__name__)
[docs]class TensorflowDatasets(BaseModel):
"""访问 TensorFlow 数据集。
当前实现仅适用于适合内存的数据集。
`TensorFlow 数据集` 是一个准备好供 TensorFlow 或其他 Python 机器学习框架(如 Jax)使用的数据集集合。所有数据集都暴露为 `tf.data.Datasets`。
要开始,请查看指南:https://www.tensorflow.org/datasets/overview 和数据集列表:https://www.tensorflow.org/datasets/catalog/overview#all_datasets
您必须提供 sample_to_document_function:一个将数据集特定格式的样本转换为文档的函数。
属性:
dataset_name:要加载的数据集的名称
split_name:要加载的拆分的名称。默认为“train”。
load_max_docs:加载文档数量的限制。默认为 100。
sample_to_document_function:将数据集样本转换为文档的函数
示例:
.. code-block:: python
from langchain_community.utilities import TensorflowDatasets
def mlqaen_example_to_document(example: dict) -> Document:
return Document(
page_content=decode_to_str(example["context"]),
metadata={
"id": decode_to_str(example["id"]),
"title": decode_to_str(example["title"]),
"question": decode_to_str(example["question"]),
"answer": decode_to_str(example["answers"]["text"][0]),
},
)
tsds_client = TensorflowDatasets(
dataset_name="mlqa/en",
split_name="train",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)"""
dataset_name: str = ""
split_name: str = "train"
load_max_docs: int = 100
sample_to_document_function: Optional[Callable[[Dict], Document]] = None
dataset: Any #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""验证Python包是否存在于环境中。"""
try:
import tensorflow # noqa: F401
except ImportError:
raise ImportError(
"Could not import tensorflow python package. "
"Please install it with `pip install tensorflow`."
)
try:
import tensorflow_datasets
except ImportError:
raise ImportError(
"Could not import tensorflow_datasets python package. "
"Please install it with `pip install tensorflow-datasets`."
)
if values["sample_to_document_function"] is None:
raise ValueError(
"sample_to_document_function is None. "
"Please provide a function that converts a dataset sample to"
" a Document."
)
values["dataset"] = tensorflow_datasets.load(
values["dataset_name"], split=values["split_name"]
)
return values
[docs] def lazy_load(self) -> Iterator[Document]:
"""下载所选数据集的懒加载方式。
返回:一个文档的迭代器。
"""
return (
self.sample_to_document_function(s)
for s in self.dataset.take(self.load_max_docs)
if self.sample_to_document_function is not None
)
[docs] def load(self) -> List[Document]:
"""下载所选数据集。
返回:文档列表。
"""
return list(self.lazy_load())