jax.sharding
模块#
类#
- class jax.sharding.Sharding#
描述了
jax.Array
如何在设备上布局。- addressable_devices_indices_map(global_shape)[源代码][源代码]#
从可寻址设备到每个设备包含的数组数据片段的映射。
addressable_devices_indices_map
包含device_indices_map
中适用于可寻址设备的部分。- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index | None]
- devices_indices_map(global_shape)[源代码][源代码]#
返回一个从设备到每个设备包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[源代码][源代码]#
如果两个分片是等价的,则返回
True
。如果两个分片将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。
例如,如果
NamedSharding
和PositionalSharding
都将数组的相同分片放置在相同的设备上,那么它们可能是等价的。
- property is_fully_addressable: bool[源代码]#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- class jax.sharding.SingleDeviceSharding#
基类:
Sharding
一个将数据放置在单个设备上的
分片
。- 参数:
device – 一个单独的
设备
。
示例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
- devices_indices_map(global_shape)[源代码][源代码]#
返回一个从设备到每个设备包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- property is_fully_addressable: bool[源代码]#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- class jax.sharding.NamedSharding#
基类:
Sharding
一个
NamedSharding
使用命名轴来表达分片。一个
NamedSharding
是一对Mesh
设备和PartitionSpec
,它描述了如何在该网格上对数组进行分片。一个
Mesh
是一个多维的 JAX 设备 NumPy 数组,其中网格的每个轴都有一个名称,例如'x'
或'y'
。一个
PartitionSpec
是一个元组,其元素可以是None
、一个网格轴,或一个网格轴的元组。每个元素描述了一个输入维度如何在零个或多个网格维度上进行分区。例如,PartitionSpec('x', 'y')
表示数据的第一维度在网格的x
轴上分片,第二维度在网格的y
轴上分片。分布式数组和自动并行化(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程提供了更多细节和图表,解释了如何使用
Mesh
和PartitionSpec
。- 参数:
mesh – 一个
jax.sharding.Mesh
对象。spec – 一个
jax.sharding.PartitionSpec
对象。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property is_fully_addressable: bool[源代码]#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- property mesh#
(self) -> object
- property spec#
(self) -> object
- class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[源代码][源代码]#
基类:
Sharding
- 参数:
devices (Sequence[xc.Device] | np.ndarray)
memory_kind (str | None)
- property is_fully_addressable: bool#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- class jax.sharding.PmapSharding#
基类:
Sharding
描述了由
jax.pmap()
使用的分片。- classmethod default(shape, sharded_dim=0, devices=None)[源代码][源代码]#
创建一个
PmapSharding
,它匹配jax.pmap()
使用的默认放置。- 参数:
shape (Shape) – 输入数组的形状。
sharded_dim (int) – 输入数组分片的维度。默认为 0。
devices (Sequence[xc.Device] | None) – 可选的设备使用顺序。如果省略,则使用隐式
used (device order used by pmap is) –
jax.local_devices()
.of (which is the order) –
jax.local_devices()
.
- 返回类型:
- property devices#
(self) -> ndarray
- devices_indices_map(global_shape)[源代码][源代码]#
返回一个从设备到每个设备包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[源代码][源代码]#
如果两个分片是等价的,则返回
True
。如果两个分片将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。
例如,如果
NamedSharding
和PositionalSharding
都将数组的相同分片放置在相同的设备上,那么它们可能是等价的。- 参数:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- 返回类型:
- property is_fully_addressable: bool#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- shard_shape(global_shape)[源代码][源代码]#
返回每个设备上数据的形状。
此函数返回的分片形状是根据
global_shape
和分片属性计算得出的。- 参数:
global_shape (Shape)
- 返回类型:
Shape
- property sharding_spec#
(self) -> jax::ShardingSpec
- class jax.sharding.GSPMDSharding#
基类:
Sharding
- property is_fully_addressable: bool#
这种分片是完全可寻址的吗?
如果当前进程可以访问
Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
等同于多进程 JAX 中的 “is_local”。
- class jax.sharding.PartitionSpec(*partitions)[源代码][源代码]#
描述如何在设备网格上划分数组的元组。
每个元素要么是
None
,要么是一个字符串,或者是一个字符串的元组。更多详情请参阅jax.sharding.NamedSharding
的文档。此类存在是为了让 JAX 的 pytree 工具能够区分分区规范与应视为 pytrees 的元组。
- class jax.sharding.Mesh(devices, axis_names)[源代码][源代码]#
声明在此管理器范围内可用的硬件资源。
特别是,所有
axis_names
在管理块内成为有效的资源名称,并且可以在jax.experimental.pjit.pjit()
的in_axis_resources
参数中使用。另请参阅 JAX 的多进程编程模型(https://jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)如果你在多线程中编译,请确保
with Mesh
上下文管理器位于线程将要执行的函数内部。- 参数:
devices (np.ndarray) – 包含 JAX 设备对象的 NumPy ndarray 对象(例如从
jax.devices()
获得)。axis_names (tuple[MeshAxisName, ...]) – 要分配给
devices
参数维度的资源轴名称序列。其长度应与devices
的秩匹配。
示例
>>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> inp = np.arange(16).reshape((8, 2)) >>> devices = np.array(jax.devices()).reshape(4, 2) ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> # Use the mesh object directly as a context manager. >>> with global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager. >>> with Mesh(devices, ('x', 'y')) as global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> with global_mesh as m: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`. >>> with Mesh(devices, ('x', 'y')): ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)