ray.data.Dataset.to_torch#

Dataset.to_torch(*, label_column: str | None = None, feature_columns: List[str] | List[List[str]] | Dict[str, List[str]] | None = None, label_column_dtype: torch.dtype | None = None, feature_column_dtypes: torch.dtype | List[torch.dtype] | Dict[str, torch.dtype] | None = None, batch_size: int = 1, prefetch_batches: int = 1, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None, unsqueeze_label_tensor: bool = True, unsqueeze_feature_tensors: bool = True) torch.utils.data.IterableDataset[源代码]#

返回一个 Torch IterableDataset 覆盖此 Dataset

这仅支持可转换为 Arrow 记录的数据集。

建议直接使用返回的 IterableDataset ,而不是将其传递给 torch 的 DataLoader

IterableDataset 中的每个元素都是一个由两个元素组成的元组。第一个元素包含特征张量,第二个元素是标签张量。这些元素的形式可以根据指定的参数而变化。

对于特征张量(N 是 batch_size,n、m、k 是每个张量的特征数量):

  • 如果 feature_columns 是一个 List[str],那么特征是一个形状为 (N, n) 的张量,其列对应于 feature_columns

  • 如果 feature_columns 是一个 List[List[str]],那么特征是一个形状为 [(N, m),…,(N, k)] 的张量列表,其中每个张量的列对应于 feature_columns 的元素。

  • 如果 feature_columns 是一个 Dict[str, List[str]],那么特征是一个形状为 {key1: (N, m),…, keyN: (N, k)} 的键-张量对的字典,其中每个张量的列对应于 feature_columns 在键下的值。

如果 unsqueeze_label_tensor=True``(默认),标签张量的形状为 (N, 1)。否则,其形状为 (N,)。如果 ``label_column 被指定为 None,那么 Dataset 中的任何列都不会被视为标签,输出标签张量为 None

请注意,如果有多个 Torch 工作者要消费数据,您可能希望在此数据集上调用 Dataset.split()

备注

此操作将触发对此数据集执行的延迟转换。

时间复杂度: O(1)

参数:
  • label_column – 用作标签的列名(输出列表的第二个元素)。可以为 None 用于预测,在这种情况下,返回元组的第二个元素也将为 None。

  • feature_columns – 要使用的列名作为特征。可以是列表的列表,或者是用于多张量输出的字符串-列表对的字典。如果为 None ,则使用除标签列之外的所有列作为特征。

  • label_column_dtype – 用于标签列的 torch 数据类型。如果为 None,则自动推断数据类型。

  • feature_column_dtypes – 用于特征张量的数据类型。这应与 feature_columns 的格式匹配,或为单一数据类型,在这种情况下,它适用于所有张量。如果为 None,则自动推断数据类型。

  • batch_size – 每次产生多少个批次样本。默认为1。

  • prefetch_batches – 要预取的批次数量,超过当前批次。如果设置为大于0,则使用单独的线程池将对象获取到本地节点,格式化批次,并应用collate_fn。默认为1。

  • drop_last – 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果为 False 且流的大小不能被批次大小整除,则最后一个批次会更小。默认为 False。

  • local_shuffle_buffer_size – 如果非空,数据将使用本地内存中的随机缓冲区进行随机洗牌,并且此值将作为本地内存中随机缓冲区中必须存在的最小行数,以便生成一个批次。当没有更多的行可以添加到缓冲区时,缓冲区中剩余的行将被排空。此缓冲区大小必须大于或等于 batch_size,因此在使用本地洗牌时也必须指定 batch_size

  • local_shuffle_seed – 用于本地随机洗牌的种子。

  • unsqueeze_label_tensor – 如果设置为 True,标签张量将被 unsqueeze(重塑为 (N, 1))。否则,它将保持原样,即 (N, )。通常,回归损失函数期望一个 unsqueeze 的张量,而分类损失函数期望一个 squeeze 的张量。默认为 True。

  • unsqueeze_feature_tensors – 如果设置为 True,特征张量在连接到最终特征张量之前会被压缩(重塑为 (N, 1))。否则,它们保持原样,即 (N, )。默认为 True。

返回:

一个 Torch IterableDataset