LayoutMap
classkeras.distribution.LayoutMap(device_mesh)
一个类似字典的对象,映射字符串到 TensorLayout
实例.
LayoutMap
使用字符串作为键,TensorLayout
作为值.在普通 Python 字典和此类之间存在行为差异.字符串键在检索值时将被视为正则表达式.有关更多详细信息,请参阅 get
的文档字符串.
请参见下面的使用示例.您可以定义 TensorLayout
的命名模式,然后检索相应的 TensorLayout
实例.
在正常情况下,查询的键通常是 variable.path
,这是变量的标识符.
作为快捷方式,插入时也允许使用轴名称的元组或列表作为值,并将转换为 TensorLayout
.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel'] # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias'] # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias'] # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel'] # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias'] # layout_8 == None
参数:
device_mesh: keras.distribution.DeviceMesh
实例.
DeviceMesh
classkeras.distribution.DeviceMesh(shape, axis_names, devices=None)
分布式计算的计算设备集群.
此API与jax.sharding.Mesh
和tf.dtensor.Mesh
对齐,表示全局上下文中的计算设备.
更多详情请参见jax.sharding.Mesh 和tf.dtensor.Mesh.
参数:
shape: 整数元组或列表.整体DeviceMesh
的形状,例如(8,)
表示仅数据并行的分布,
或(4, 2)
表示模型+数据并行的分布.
axis_names: 字符串列表.每个轴的逻辑名称,用于DeviceMesh
.axis_names
的长度应与shape
的秩匹配.
axis_names
将用于匹配/创建TensorLayout
,以便在分发数据和变量时使用.
devices: 可选的设备列表.默认为从keras.distribution.list_devices()
获取的本地所有可用设备.
TensorLayout
classkeras.distribution.TensorLayout(axes, device_mesh=None)
用于应用于张量的布局.
此API与jax.sharding.NamedSharding
和tf.dtensor.Layout
对齐.
更多详细信息请参见jax.sharding.NamedSharding 和tf.dtensor.Layout.
参数:
axes: 应映射到DeviceMesh
中axis_names
的元组字符串.对于任何不需要分片的维度,
可以使用None
作为占位符.
device_mesh: 可选的DeviceMesh
,将用于创建布局.张量到物理设备的实际映射
在指定网格之前是未知的.
distribute_tensor
functionkeras.distribution.distribute_tensor(tensor, layout)
改变jit函数执行中Tensor值的布局.
参数:
tensor: 要改变布局的Tensor.
layout: 要应用于值的TensorLayout
.
返回: 具有指定tensor布局的新值.