Source code for langchain_community.utilities.tensorflow_datasets

import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)


[docs]class TensorflowDatasets(BaseModel): """访问 TensorFlow 数据集。 当前实现仅适用于适合内存的数据集。 `TensorFlow 数据集` 是一个准备好供 TensorFlow 或其他 Python 机器学习框架(如 Jax)使用的数据集集合。所有数据集都暴露为 `tf.data.Datasets`。 要开始,请查看指南:https://www.tensorflow.org/datasets/overview 和数据集列表:https://www.tensorflow.org/datasets/catalog/overview#all_datasets 您必须提供 sample_to_document_function:一个将数据集特定格式的样本转换为文档的函数。 属性: dataset_name:要加载的数据集的名称 split_name:要加载的拆分的名称。默认为“train”。 load_max_docs:加载文档数量的限制。默认为 100。 sample_to_document_function:将数据集样本转换为文档的函数 示例: .. code-block:: python from langchain_community.utilities import TensorflowDatasets def mlqaen_example_to_document(example: dict) -> Document: return Document( page_content=decode_to_str(example["context"]), metadata={ "id": decode_to_str(example["id"]), "title": decode_to_str(example["title"]), "question": decode_to_str(example["question"]), "answer": decode_to_str(example["answers"]["text"][0]), }, ) tsds_client = TensorflowDatasets( dataset_name="mlqa/en", split_name="train", load_max_docs=MAX_DOCS, sample_to_document_function=mlqaen_example_to_document, )""" dataset_name: str = "" split_name: str = "train" load_max_docs: int = 100 sample_to_document_function: Optional[Callable[[Dict], Document]] = None dataset: Any #: :meta private: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """验证Python包是否存在于环境中。""" try: import tensorflow # noqa: F401 except ImportError: raise ImportError( "Could not import tensorflow python package. " "Please install it with `pip install tensorflow`." ) try: import tensorflow_datasets except ImportError: raise ImportError( "Could not import tensorflow_datasets python package. " "Please install it with `pip install tensorflow-datasets`." ) if values["sample_to_document_function"] is None: raise ValueError( "sample_to_document_function is None. " "Please provide a function that converts a dataset sample to" " a Document." ) values["dataset"] = tensorflow_datasets.load( values["dataset_name"], split=values["split_name"] ) return values
[docs] def lazy_load(self) -> Iterator[Document]: """下载所选数据集的懒加载方式。 返回:一个文档的迭代器。 """ return ( self.sample_to_document_function(s) for s in self.dataset.take(self.load_max_docs) if self.sample_to_document_function is not None )
[docs] def load(self) -> List[Document]: """下载所选数据集。 返回:文档列表。 """ return list(self.lazy_load())