Keras 3 API 文档 / 数据加载 / 文本数据加载

文本数据加载

[source]

text_dataset_from_directory function

keras.utils.text_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    batch_size=32,
    max_length=None,
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    follow_links=False,
    verbose=True,
)

生成一个从目录中的文本文件创建的tf.data.Dataset.

如果你的目录结构是:

main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt

那么调用text_dataset_from_directory(main_directory, labels='inferred')将返回一个tf.data.Dataset,该数据集生成来自子目录class_aclass_b的文本批次,以及标签0和1(0对应于class_a,1对应于class_b).

目前仅支持.txt文件.

参数: directory: 数据所在的目录. 如果labels"inferred",它应该包含 子目录,每个子目录包含一个类别的文本文件. 否则,目录结构将被忽略. labels: 可以是"inferred" (标签从目录结构生成), None(无标签), 或一个整数标签的列表/元组,其大小与目录中找到的文本文件数量相同.标签应根据文本文件路径的字母数字顺序进行排序 (通过Python中的os.walk(directory)获得). label_mode: 描述labels编码的字符串.选项有: - "int":表示标签编码为整数 (例如用于sparse_categorical_crossentropy损失). - "categorical":表示标签编码为分类向量 (例如用于categorical_crossentropy损失). - "binary":表示标签(只能有2个) 编码为值为0或1的float32标量 (例如用于binary_crossentropy). - None(无标签). class_names: 仅在"labels""inferred"时有效. 这是类名的显式列表 (必须与子目录名称匹配).用于控制类的顺序(否则使用字母数字顺序). batch_size: 数据的批次大小. 如果为None,数据将不会分批 (数据集将生成单个样本). 默认为32. max_length: 文本字符串的最大长度.超过此长度的文本将被截断至max_length. shuffle: 是否打乱数据. 如果设置为False,数据将按字母数字顺序排序. 默认为True. seed: 用于打乱和变换的可选随机种子. validation_split: 可选的介于0和1之间的浮点数, 用于保留为验证的数据比例. subset: 要返回的数据子集. 可以是"training""validation""both". 仅在设置validation_split时使用. 当subset="both"时,该工具返回两个数据集的元组 (分别是训练和验证数据集). follow_links: 是否访问由符号链接指向的子目录. 默认为False. verbose: 是否显示类和找到的文件数量的信息.默认为True.

返回:

一个tf.data.Dataset对象.

  • 如果label_modeNone,它生成形状为(batch_size,)string张量,包含一批文本文件的内容.
  • 否则,它生成一个元组(texts, labels),其中texts 形状为(batch_size,),labels遵循下面描述的格式.

标签格式规则:

  • 如果label_modeint,标签是形状为(batch_size,)int32张量.
  • 如果label_modebinary,标签是形状为(batch_size, 1)float32张量,值为1和0.
  • 如果label_modecategorical,标签是形状为(batch_size, num_classes)float32张量,表示类索引的独热编码.