jax.默认_矩阵乘法_精度#
- jax.default_matmul_precision = <jax._src.config.State object>#
用于 jax_default_matmul_precision 配置选项的上下文管理器。
控制32位输入的默认matmul和conv精度。
一些平台,如 TPU,为矩阵乘法和卷积计算提供了可配置的精度级别,以速度换取精度。可以为每个操作控制精度;例如,请参阅
jax.lax.conv_general_dilated()和jax.lax.dot()文档字符串。但当操作未指定特定精度时,控制默认行为也是有用的。此选项可用于控制涉及32位输入的矩阵乘法和卷积计算的默认精度级别。这些级别大致描述了标量积计算的精度。’bfloat16’ 选项是最快且最不精确的;’float32’ 类似于完整的 float32 精度;’tensorfloat32’ 是中间的。
- 参数:
new_val (Any)