加载数据#

Ray Data 从各种来源加载数据。本指南将向您展示如何:

读取文件#

Ray Data 从本地磁盘或云存储中读取各种文件格式的文件。要查看支持的文件格式的完整列表,请参阅 输入/输出参考

要读取 Parquet 文件,请调用 read_parquet()

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

要读取原始图像,请调用 read_images()。Ray Data 将图像表示为 NumPy ndarrays。

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages/")

print(ds.schema())
Column  Type
------  ----
image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)

要读取文本行,请调用 read_text()

import ray

ds = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")

print(ds.schema())
Column  Type
------  ----
text    string

要读取CSV文件,调用 read_csv()

import ray

ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")

print(ds.schema())
Column             Type
------             ----
sepal length (cm)  double
sepal width (cm)   double
petal length (cm)  double
petal width (cm)   double
target             int64

要读取原始二进制文件,请调用 read_binary_files()

import ray

ds = ray.data.read_binary_files("s3://anonymous@ray-example-data/documents")

print(ds.schema())
Column  Type
------  ----
bytes   binary

要读取 TFRecords 文件,请调用 read_tfrecords()

import ray

ds = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")

print(ds.schema())
Column        Type
------        ----
label         binary
petal.length  float
sepal.width   float
petal.width   float
sepal.length  float

从本地磁盘读取文件#

要从本地磁盘读取文件,调用类似 read_parquet() 的函数,并使用 local:// 模式指定路径。路径可以指向文件或目录。

要读取除 Parquet 以外的格式,请参阅 输入/输出参考

小技巧

如果你的文件在每个节点上都可以访问,排除 local:// 以在集群中并行化读取任务。

import ray

ds = ray.data.read_parquet("local:///tmp/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

从云存储读取文件#

要读取云存储中的文件,请使用您的云服务提供商对所有节点进行身份验证。然后,调用类似 read_parquet() 的方法,并指定带有适当模式的URI。URI可以指向存储桶、文件夹或对象。

要读取除 Parquet 以外的格式,请参阅 输入/输出参考

要从 Amazon S3 读取文件,请使用 s3:// 方案指定 URI。

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行与 Amazon S3 的身份验证。有关如何配置您的凭证以兼容 PyArrow 的更多信息,请参阅他们的 S3 文件系统文档

要从 Google Cloud Storage 读取文件,请安装 Google Cloud Storage 的文件系统接口

pip install gcsfs

然后,创建一个 GCSFileSystem 并使用 gcs:// 方案指定URI。

import ray

filesystem = gcsfs.GCSFileSystem(project="my-google-project")
ds = ray.data.read_parquet(
    "gcs://anonymous@ray-example-data/iris.parquet",
    filesystem=filesystem
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行与 Google Cloud Storage 的身份验证。有关如何配置您的凭据以兼容 PyArrow 的更多信息,请参阅他们的 GCS 文件系统文档

要从 Azure Blob 存储读取文件,请安装 Azure-Datalake Gen1 和 Gen2 存储的文件系统接口

pip install adlfs

然后,创建一个 AzureBlobFileSystem 并使用 az:// 方案指定 URI。

import adlfs
import ray

ds = ray.data.read_parquet(
    "az://ray-example-data/iris.parquet",
    adlfs.AzureBlobFileSystem(account_name="azureopendatastorage")
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖于 PyArrow 进行 Azure Blob 存储的身份验证。有关如何配置您的凭证以兼容 PyArrow 的更多信息,请参阅他们的 fsspec 兼容文件系统文档

从NFS读取文件#

要从NFS文件系统读取文件,调用类似 read_parquet() 的函数,并在挂载的文件系统上指定文件。路径可以指向文件或目录。

要读取除 Parquet 以外的格式,请参阅 输入/输出参考

import ray

ds = ray.data.read_parquet("/mnt/cluster_storage/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

处理压缩文件#

要读取压缩文件,请在 arrow_open_stream_args 中指定 compression。您可以使用 Arrow 支持的任何编解码器

import ray

ds = ray.data.read_csv(
    "s3://anonymous@ray-example-data/iris.csv.gz",
    arrow_open_stream_args={"compression": "gzip"},
)

从其他库加载数据#

从单节点数据库加载数据#

Ray Data 与 pandas、NumPy 和 Arrow 等库互操作。

要从 Python 对象创建 Dataset,请调用 from_items() 并传入一个 Dict 列表。Ray Data 将每个 Dict 视为一行。

import ray

ds = ray.data.from_items([
    {"food": "spam", "price": 9.34},
    {"food": "ham", "price": 5.37},
    {"food": "eggs", "price": 0.94}
])

print(ds)
MaterializedDataset(
   num_blocks=3,
   num_rows=3,
   schema={food: string, price: double}
)

你也可以从一个常规的 Python 对象列表创建一个 Dataset

import ray

ds = ray.data.from_items([1, 2, 3, 4, 5])

print(ds)
MaterializedDataset(num_blocks=5, num_rows=5, schema={item: int64})

要从一个 NumPy 数组创建 Dataset,请调用 from_numpy()。Ray Data 将外轴视为行维度。

import numpy as np
import ray

array = np.ones((3, 2, 2))
ds = ray.data.from_numpy(array)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={data: numpy.ndarray(shape=(2, 2), dtype=double)}
)

要从 pandas DataFrame 创建 Dataset,请调用 from_pandas()

import pandas as pd
import ray

df = pd.DataFrame({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_pandas(df)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: object, price: float64}
)

要从 Arrow 表创建 Dataset ,请调用 from_arrow()

import pyarrow as pa

table = pa.table({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_arrow(table)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: string, price: double}
)

从分布式DataFrame库加载数据#

Ray Data 与分布式数据处理框架如 DaskSparkModinMars 相互操作。

备注

Ray 社区提供这些操作,但可能不会积极维护它们。如果你遇到问题,请在 这里 创建一个 GitHub 问题。

要从 Dask DataFrame 创建 Dataset,请调用 from_dask()。此函数构建一个由支持 Dask DataFrame 的分布式 Pandas DataFrame 分区支持的 Dataset

import dask.dataframe as dd
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
ddf = dd.from_pandas(df, npartitions=4)
# Create a Dataset from a Dask DataFrame.
ds = ray.data.from_dask(ddf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Spark DataFrame 创建一个 Dataset,请调用 from_spark()。此函数创建一个由支持 Spark DataFrame 的分布式 Spark DataFrame 分区支持的 Dataset

import ray
import raydp

spark = raydp.init_spark(app_name="Spark -> Datasets Example",
                        num_executors=2,
                        executor_cores=2,
                        executor_memory="500MB")
df = spark.createDataFrame([(i, str(i)) for i in range(10000)], ["col1", "col2"])
ds = ray.data.from_spark(df)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Iceberg Table 创建一个 Dataset,请调用 read_iceberg()。此函数创建一个由 Iceberg 表底层分布式文件支持的 Dataset

>>> import ray
>>> from pyiceberg.expressions import EqualTo
>>> ds = ray.data.read_iceberg(
...     table_identifier="db_name.table_name",
...     row_filter=EqualTo("column_name", "literal_value"),
...     catalog_kwargs={"name": "default", "type": "glue"}
... )
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Modin DataFrame 创建 Dataset,请调用 from_modin()。此函数构建一个由支持 Modin DataFrame 的分布式 Pandas DataFrame 分区支持的 Dataset

import modin.pandas as md
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df)
# Create a Dataset from a Modin DataFrame.
ds = ray.data.from_modin(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Mars DataFrame 创建 Dataset,请调用 from_mars()。此函数构建一个由 Mars DataFrame 底层分布式 Pandas DataFrame 分区支持的 Dataset

import mars
import mars.dataframe as md
import pandas as pd
import ray

cluster = mars.new_cluster_in_ray(worker_num=2, worker_cpu=1)

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df, num_partitions=8)
# Create a tabular Dataset from a Mars DataFrame.
ds = ray.data.from_mars(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

从ML库加载数据#

Ray Data 与 HuggingFace、PyTorch 和 TensorFlow 数据集互操作。

要将 HuggingFace 数据集转换为 Ray 数据集,请调用 from_huggingface()。此函数访问底层 Arrow 表并将其直接转换为数据集。

警告

from_huggingface 仅在某些情况下支持并行读取,即对于未转换的公共 HuggingFace 数据集。对于这些数据集,Ray Data 使用 托管的 parquet 文件 来执行分布式读取;否则,Ray Data 使用单节点读取。这种行为对于内存中的 HuggingFace 数据集不应成为问题,但可能会导致大型内存映射 HuggingFace 数据集的失败。此外,HuggingFace DatasetDictIterableDatasetDict 对象不受支持。

import ray.data
from datasets import load_dataset

hf_ds = load_dataset("wikitext", "wikitext-2-raw-v1")
ray_ds = ray.data.from_huggingface(hf_ds["train"])
ray_ds.take(2)
[{'text': ''}, {'text': ' = Valkyria Chronicles III = \n'}]

要将 PyTorch 数据集转换为 Ray 数据集,请调用 from_torch()

import ray
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

tds = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
ds = ray.data.from_torch(tds)

print(ds)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|███████████████████████| 170498071/170498071 [00:07<00:00, 23494838.54it/s]
Extracting data/cifar-10-python.tar.gz to data
Dataset(num_rows=50000, schema={item: object})

要将 TensorFlow 数据集转换为 Ray 数据集,请调用 from_tf()

警告

from_tf 不支持并行读取。仅在处理像MNIST或CIFAR这样的小数据集时使用此函数。

import ray
import tensorflow_datasets as tfds

tf_ds, _ = tfds.load("cifar10", split=["train", "test"])
ds = ray.data.from_tf(tf_ds)

print(ds)
MaterializedDataset(
   num_blocks=...,
   num_rows=50000,
   schema={
      id: binary,
      image: numpy.ndarray(shape=(32, 32, 3), dtype=uint8),
      label: int64
   }
)

读取数据库#

Ray Data 从 MySQL、PostgreSQL、MongoDB 和 BigQuery 等数据库中读取数据。

读取 SQL 数据库#

调用 read_sql() 从提供 Python DB API2-compliant 连接器的数据库中读取数据。

要从 MySQL 读取数据,请安装 MySQL Connector/Python。这是官方的 MySQL 数据库连接器。

pip install mysql-connector-python

然后,定义你的连接逻辑并查询数据库。

import mysql.connector

import ray

def create_connection():
    return mysql.connector.connect(
        user="admin",
        password=...,
        host="example-mysql-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        connection_timeout=30,
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 PostgreSQL 读取数据,请安装 Psycopg 2。它是使用最广泛的 PostgreSQL 数据库连接器。

pip install psycopg2-binary

然后,定义你的连接逻辑并查询数据库。

import psycopg2

import ray

def create_connection():
    return psycopg2.connect(
        user="postgres",
        password=...,
        host="example-postgres-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        dbname="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Snowflake 读取数据,请安装 Snowflake Connector for Python

pip install snowflake-connector-python

然后,定义你的连接逻辑并查询数据库。

import snowflake.connector

import ray

def create_connection():
    return snowflake.connector.connect(
        user=...,
        password=...
        account="ZZKXUVH-IPB52023",
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Databricks 读取数据,请将 DATABRICKS_TOKEN 环境变量设置为您的 Databricks 仓库访问令牌。

export DATABRICKS_TOKEN=...

如果你没有在 Databricks 运行时上运行你的程序,还需要设置 DATABRICKS_HOST 环境变量。

export DATABRICKS_HOST=adb-<workspace-id>.<random-number>.azuredatabricks.net

然后,调用 ray.data.read_databricks_tables() 从 Databricks SQL 仓库读取数据。

import ray

dataset = ray.data.read_databricks_tables(
    warehouse_id='...',  # Databricks SQL warehouse ID
    catalog='catalog_1',  # Unity catalog name
    schema='db_1',  # Schema name
    query="SELECT title, score FROM movie WHERE year >= 1980",
)

要从 BigQuery 读取数据,请安装 Google BigQuery 的 Python 客户端Google BigQueryStorage 的 Python 客户端

pip install google-cloud-bigquery
pip install google-cloud-bigquery-storage

要从 BigQuery 读取数据,请调用 read_bigquery() 并指定项目 ID、数据集和查询(如果适用)。

import ray

# Read the entire dataset. Do not specify query.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    dataset="bigquery-public-data.ml_datasets.iris",
)

# Read from a SQL query of the dataset. Do not specify dataset.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    query = "SELECT * FROM `bigquery-public-data.ml_datasets.iris` LIMIT 50",
)

# Write back to BigQuery
ds.write_bigquery(
    project_id="my_gcloud_project_id",
    dataset="destination_dataset.destination_table",
    overwrite_table=True,
)

读取 MongoDB#

要从 MongoDB 读取数据,调用 read_mongo() 并指定源 URI、数据库和集合。您还需要指定一个要对集合运行的管道。

import ray

# Read a local MongoDB.
ds = ray.data.read_mongo(
    uri="mongodb://localhost:27017",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Reading a remote MongoDB is the same.
ds = ray.data.read_mongo(
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Write back to MongoDB.
ds.write_mongo(
    MongoDatasource(),
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
)

创建合成数据#

合成数据集可用于测试和基准测试。

要从一系列整数创建一个合成的 Dataset,调用 range()。Ray Data 将整数范围存储在单个列中。

import ray

ds = ray.data.range(10000)

print(ds.schema())
Column  Type
------  ----
id      int64

要创建包含数组的合成 Dataset ,请调用 range_tensor() 。Ray Data 将一个整数范围打包成具有提供形状的 ndarray。

import ray

ds = ray.data.range_tensor(10, shape=(64, 64))

print(ds.schema())
Column  Type
------  ----
data    numpy.ndarray(shape=(64, 64), dtype=int64)

加载其他数据源#

如果 Ray Data 无法加载您的数据,请子类化 Datasource。然后,构造您的自定义数据源实例,并将其传递给 read_datasource()。要写入结果,您可能还需要子类化 ray.data.Datasink。然后,创建您的自定义数据接收器实例,并将其传递给 write_datasink()。更多详情,请参阅 高级: 读取和写入自定义文件类型

# Read from a custom datasource.
ds = ray.data.read_datasource(YourCustomDatasource(), **read_args)

# Write to a custom datasink.
ds.write_datasink(YourCustomDatasink())

性能考虑#

默认情况下,所有读取任务的输出块数量是根据输入数据大小和可用资源动态决定的。在大多数情况下,这应该能正常工作。不过,您也可以通过设置 override_num_blocks 参数来覆盖默认值。Ray Data 内部决定运行多少个读取任务以最佳利用集群,范围从 1...override_num_blocks 个任务。换句话说,override_num_blocks 越高,数据集中的数据块越小,从而提供更多的并行执行机会。

有关如何调整输出块的数量以及优化读取性能的其他建议的更多信息,请参阅 优化读取