Keras 3 API 文档 / 混合精度 / 混合精度策略 API

混合精度策略 API

[source]

DTypePolicy class

keras.mixed_precision.DTypePolicy(name=None)

一个Keras层的dtype策略.

一个dtype策略决定了层的计算和变量dtype.每个层都有一个策略.策略可以通过层的dtype参数传递给层构造函数,或者可以通过keras.config.set_dtype_policy设置全局策略.

参数: name: 策略名称,它决定了计算和变量dtype.可以是任何dtype名称,例如"float32""float64",这将导致计算和变量dtype都是该dtype.也可以是字符串"mixed_float16""mixed_bfloat16",这将导致计算dtype为float16bfloat16,而变量dtype为float32.

通常,只有在使用混合精度时才需要与dtype策略进行交互,混合精度是指使用float16或bfloat16进行计算,使用float32进行变量.这就是为什么术语mixed_precision出现在API名称中.可以通过将"mixed_float16""mixed_bfloat16"传递给keras.mixed_precision.set_dtype_policy()来启用混合精度.

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1将自动使用混合精度
<DTypePolicy "mixed_float16">
>>> # 可以选择性地将层覆盖为使用float32而不是混合精度.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # 将策略设置回初始的float32.
>>> keras.config.set_dtype_policy('float32')

在上面的例子中,将dtype="float32"传递给层等同于传递dtype=keras.config.DTypePolicy("float32").一般来说,将一个dtype策略名称传递给层等同于传递相应的策略,因此不需要显式构造一个DTypePolicy对象.


[source]

dtype_policy function

keras.mixed_precision.dtype_policy()

返回当前默认的数据类型策略对象.


[source]

set_dtype_policy function

keras.mixed_precision.set_dtype_policy(policy)

设置默认的dtype策略全局生效.

示例:

>>> keras.config.set_dtype_policy("mixed_float16")