ray.data.Dataset.iter_tf_batches#

Dataset.iter_tf_batches(*, prefetch_batches: int = 1, batch_size: int | None = 256, dtypes: tf.dtypes.DType | Dict[str, tf.dtypes.DType] | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None) Iterable[tf.Tensor | Dict[str, tf.Tensor]][源代码]#

返回一个迭代器,遍历表示为 TensorFlow 张量的数据批次。

这个可迭代对象产生类型为 Dict[str, tf.Tensor] 的批次。为了获得更多灵活性,可以调用 iter_batches() 并手动将数据转换为 TensorFlow 张量。

小技巧

如果你不需要这种方法提供的额外灵活性,可以考虑使用 to_tf() 。它更容易使用。

备注

此操作将触发对此数据集执行的延迟转换。

示例

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

tf_dataset = ds.to_tf(
    feature_columns="sepal length (cm)",
    label_columns="target",
    batch_size=2
)
for features, labels in tf_dataset:
    print(features, labels)
tf.Tensor([5.1 4.9], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64)
...
tf.Tensor([6.2 5.9], shape=(2,), dtype=float64) tf.Tensor([2 2], shape=(2,), dtype=int64)

时间复杂度: O(1)

参数:
  • prefetch_batches – 要预取的批次数量,超过当前批次。如果设置为大于0,将使用一个单独的线程池来将对象获取到本地节点,格式化批次,并应用 collate_fn。默认为1。

  • batch_size – 每个批次中的行数,或 None 以使用整个块作为批次(块可能包含不同数量的行)。如果 drop_lastFalse,则最后一个批次可能包含少于 batch_size 行。默认为 256。

  • dtypes – 创建的张量(s)的TensorFlow dtype(s);如果 None,则从张量数据推断dtype。

  • drop_last – 如果最后一个批次不完整,是否丢弃它。

  • local_shuffle_buffer_size – 如果不是 None ,数据将使用本地内存中的随机洗牌缓冲区进行随机洗牌,并且此值作为本地内存中随机洗牌缓冲区中必须存在的最小行数,以便生成一个批次。当没有更多的行可以添加到缓冲区时,缓冲区中剩余的行将被排空。使用本地洗牌时,还必须指定 batch_size

  • local_shuffle_seed – 用于本地随机洗牌的种子。

返回:

一个遍历 TensorFlow 张量批次的可迭代对象。

参见

Dataset.iter_batches()

调用此方法以手动将您的数据转换为 TensorFlow 张量。