ray.data.Dataset.iter_tf_batches#
- Dataset.iter_tf_batches(*, prefetch_batches: int = 1, batch_size: int | None = 256, dtypes: tf.dtypes.DType | Dict[str, tf.dtypes.DType] | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None) Iterable[tf.Tensor | Dict[str, tf.Tensor]] [源代码]#
返回一个迭代器,遍历表示为 TensorFlow 张量的数据批次。
这个可迭代对象产生类型为
Dict[str, tf.Tensor]
的批次。为了获得更多灵活性,可以调用iter_batches()
并手动将数据转换为 TensorFlow 张量。小技巧
如果你不需要这种方法提供的额外灵活性,可以考虑使用
to_tf()
。它更容易使用。备注
此操作将触发对此数据集执行的延迟转换。
示例
import ray ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") tf_dataset = ds.to_tf( feature_columns="sepal length (cm)", label_columns="target", batch_size=2 ) for features, labels in tf_dataset: print(features, labels)
tf.Tensor([5.1 4.9], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64) ... tf.Tensor([6.2 5.9], shape=(2,), dtype=float64) tf.Tensor([2 2], shape=(2,), dtype=int64)
时间复杂度: O(1)
- 参数:
prefetch_batches – 要预取的批次数量,超过当前批次。如果设置为大于0,将使用一个单独的线程池来将对象获取到本地节点,格式化批次,并应用
collate_fn
。默认为1。batch_size – 每个批次中的行数,或
None
以使用整个块作为批次(块可能包含不同数量的行)。如果drop_last
为False
,则最后一个批次可能包含少于batch_size
行。默认为 256。dtypes – 创建的张量(s)的TensorFlow dtype(s);如果
None
,则从张量数据推断dtype。drop_last – 如果最后一个批次不完整,是否丢弃它。
local_shuffle_buffer_size – 如果不是
None
,数据将使用本地内存中的随机洗牌缓冲区进行随机洗牌,并且此值作为本地内存中随机洗牌缓冲区中必须存在的最小行数,以便生成一个批次。当没有更多的行可以添加到缓冲区时,缓冲区中剩余的行将被排空。使用本地洗牌时,还必须指定batch_size
。local_shuffle_seed – 用于本地随机洗牌的种子。
- 返回:
一个遍历 TensorFlow 张量批次的可迭代对象。
参见
Dataset.iter_batches()
调用此方法以手动将您的数据转换为 TensorFlow 张量。