jax.sharding 模块

目录

jax.sharding 模块#

#

class jax.sharding.Sharding#

描述了 jax.Array 如何在设备上布局。

property addressable_devices: set[Device]#

Sharding 中,当前进程可寻址的设备集合。

addressable_devices_indices_map(global_shape)[源代码][源代码]#

从可寻址设备到每个设备包含的数组数据片段的映射。

addressable_devices_indices_map 包含 device_indices_map 中适用于可寻址设备的部分。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index | None]

property device_set: set[Device][源代码]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

devices_indices_map(global_shape)[源代码][源代码]#

返回一个从设备到每个设备包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index]

is_equivalent_to(other, ndim)[源代码][源代码]#

如果两个分片是等价的,则返回 True

如果两个分片将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。

例如,如果 NamedShardingPositionalSharding 都将数组的相同分片放置在相同的设备上,那么它们可能是等价的。

参数:
返回类型:

bool

property is_fully_addressable: bool[源代码]#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool[源代码]#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property num_devices: int[源代码]#

分片包含的设备数量。

shard_shape(global_shape)[源代码][源代码]#

返回每个设备上数据的形状。

此函数返回的分片形状是根据 global_shape 和分片属性计算得出的。

参数:

global_shape (Shape)

返回类型:

Shape

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

Sharding

class jax.sharding.SingleDeviceSharding#

基类:Sharding

一个将数据放置在单个设备上的 分片

参数:

device – 一个单独的 设备

示例

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
property device_set: set[Device][源代码]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

devices_indices_map(global_shape)[源代码][源代码]#

返回一个从设备到每个设备包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index]

property is_fully_addressable: bool[源代码]#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool[源代码]#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property num_devices: int[源代码]#

分片包含的设备数量。

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

SingleDeviceSharding

class jax.sharding.NamedSharding#

基类:Sharding

一个 NamedSharding 使用命名轴来表达分片。

一个 NamedSharding 是一对 Mesh 设备和 PartitionSpec ,它描述了如何在该网格上对数组进行分片。

一个 Mesh 是一个多维的 JAX 设备 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x''y'

一个 PartitionSpec 是一个元组,其元素可以是 None 、一个网格轴,或一个网格轴的元组。每个元素描述了一个输入维度如何在零个或多个网格维度上进行分区。例如,PartitionSpec('x', 'y') 表示数据的第一维度在网格的 x 轴上分片,第二维度在网格的 y 轴上分片。

分布式数组和自动并行化(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程提供了更多细节和图表,解释了如何使用 MeshPartitionSpec

参数:

示例

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device][源代码]#

Sharding 中,当前进程可寻址的设备集合。

property device_set: set[Device][源代码]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool[源代码]#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property mesh#

(self) -> object

property num_devices: int[源代码]#

分片包含的设备数量。

property spec#

(self) -> object

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

NamedSharding

class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[源代码][源代码]#

基类:Sharding

参数:
  • devices (Sequence[xc.Device] | np.ndarray)

  • memory_kind (str | None)

property device_set: set[xc.Device]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property num_devices: int[源代码]#

分片包含的设备数量。

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

PositionalSharding

class jax.sharding.PmapSharding#

基类:Sharding

描述了由 jax.pmap() 使用的分片。

classmethod default(shape, sharded_dim=0, devices=None)[源代码][源代码]#

创建一个 PmapSharding ,它匹配 jax.pmap() 使用的默认放置。

参数:
  • shape (Shape) – 输入数组的形状。

  • sharded_dim (int) – 输入数组分片的维度。默认为 0。

  • devices (Sequence[xc.Device] | None) – 可选的设备使用顺序。如果省略,则使用隐式

  • used (device order used by pmap is) – jax.local_devices().

  • of (which is the order) – jax.local_devices().

返回类型:

PmapSharding

property device_set: set[Device]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

property devices#

(self) -> ndarray

devices_indices_map(global_shape)[源代码][源代码]#

返回一个从设备到每个设备包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index]

is_equivalent_to(other, ndim)[源代码][源代码]#

如果两个分片是等价的,则返回 True

如果两个分片将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。

例如,如果 NamedShardingPositionalSharding 都将数组的相同分片放置在相同的设备上,那么它们可能是等价的。

参数:
返回类型:

bool

property is_fully_addressable: bool#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property num_devices: int[源代码]#

分片包含的设备数量。

shard_shape(global_shape)[源代码][源代码]#

返回每个设备上数据的形状。

此函数返回的分片形状是根据 global_shape 和分片属性计算得出的。

参数:

global_shape (Shape)

返回类型:

Shape

property sharding_spec#

(self) -> jax::ShardingSpec

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

class jax.sharding.GSPMDSharding#

基类:Sharding

property device_set: set[Device]#

这个 分片 所涵盖的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool#

这种分片是完全可寻址的吗?

如果当前进程可以访问 Sharding 中命名的所有设备,则分片是完全可寻址的。is_fully_addressable 等同于多进程 JAX 中的 “is_local”。

property is_fully_replicated: bool#

这种分片是完全复制的吗?

如果每个设备都有整个数据的完整副本,那么分片就是完全复制的。

property memory_kind: str | None[源代码]#

返回分片的内存类型。

property num_devices: int[源代码]#

分片包含的设备数量。

with_memory_kind(kind)[源代码][源代码]#

返回一个具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

GSPMDSharding

class jax.sharding.PartitionSpec(*partitions)[源代码][源代码]#

描述如何在设备网格上划分数组的元组。

每个元素要么是 None ,要么是一个字符串,或者是一个字符串的元组。更多详情请参阅 jax.sharding.NamedSharding 的文档。

此类存在是为了让 JAX 的 pytree 工具能够区分分区规范与应视为 pytrees 的元组。

class jax.sharding.Mesh(devices, axis_names)[源代码][源代码]#

声明在此管理器范围内可用的硬件资源。

特别是,所有 axis_names 在管理块内成为有效的资源名称,并且可以在 jax.experimental.pjit.pjit()in_axis_resources 参数中使用。另请参阅 JAX 的多进程编程模型(https://jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

如果你在多线程中编译,请确保 with Mesh 上下文管理器位于线程将要执行的函数内部。

参数:
  • devices (np.ndarray) – 包含 JAX 设备对象的 NumPy ndarray 对象(例如从 jax.devices() 获得)。

  • axis_names (tuple[MeshAxisName, ...]) – 要分配给 devices 参数维度的资源轴名称序列。其长度应与 devices 的秩匹配。

示例

>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)