Keras 3 API 文档 / 多设备分布 / LayoutMap API

LayoutMap API

[source]

LayoutMap class

keras.distribution.LayoutMap(device_mesh)

一个类似字典的对象,映射字符串到 TensorLayout 实例.

LayoutMap 使用字符串作为键,TensorLayout 作为值.在普通 Python 字典和此类之间存在行为差异.字符串键在检索值时将被视为正则表达式.有关更多详细信息,请参阅 get 的文档字符串.

请参见下面的使用示例.您可以定义 TensorLayout 的命名模式,然后检索相应的 TensorLayout 实例.

在正常情况下,查询的键通常是 variable.path,这是变量的标识符.

作为快捷方式,插入时也允许使用轴名称的元组或列表作为值,并将转换为 TensorLayout.

layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None

参数: device_mesh: keras.distribution.DeviceMesh 实例.


[source]

DeviceMesh class

keras.distribution.DeviceMesh(shape, axis_names, devices=None)

分布式计算的计算设备集群.

此API与jax.sharding.Meshtf.dtensor.Mesh对齐,表示全局上下文中的计算设备.

更多详情请参见jax.sharding.Meshtf.dtensor.Mesh.

参数: shape: 整数元组或列表.整体DeviceMesh的形状,例如(8,)表示仅数据并行的分布, 或(4, 2)表示模型+数据并行的分布. axis_names: 字符串列表.每个轴的逻辑名称,用于DeviceMesh.axis_names的长度应与shape的秩匹配. axis_names将用于匹配/创建TensorLayout,以便在分发数据和变量时使用. devices: 可选的设备列表.默认为从keras.distribution.list_devices()获取的本地所有可用设备.


[source]

TensorLayout class

keras.distribution.TensorLayout(axes, device_mesh=None)

用于应用于张量的布局.

此API与jax.sharding.NamedShardingtf.dtensor.Layout对齐.

更多详细信息请参见jax.sharding.NamedShardingtf.dtensor.Layout.

参数: axes: 应映射到DeviceMeshaxis_names的元组字符串.对于任何不需要分片的维度, 可以使用None作为占位符. device_mesh: 可选的DeviceMesh,将用于创建布局.张量到物理设备的实际映射 在指定网格之前是未知的.


[source]

distribute_tensor function

keras.distribution.distribute_tensor(tensor, layout)

改变jit函数执行中Tensor值的布局.

参数: tensor: 要改变布局的Tensor. layout: 要应用于值的TensorLayout.

返回: 具有指定tensor布局的新值.