jax.lax.conv_维度_编号

jax.lax.conv_维度_编号#

jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)[源代码][源代码]#

将卷积 dimension_numbers 转换为 ConvDimensionNumbers

参数:
  • lhs_shape – 非负整数的元组,卷积输入的形状。

  • rhs_shape – 非负整数的元组,卷积核的形状。

  • dimension_numbers – None 或一个字符串元组/列表,或一个遵循 xla_client.py 中卷积维度编号规范格式的 ConvDimensionNumbers 对象。

返回:

一个 ConvDimensionNumbers 对象,表示在 lax 函数中使用的规范形式的 dimension_numbers

返回类型:

ConvDimensionNumbers