DTypePolicy
classkeras.mixed_precision.DTypePolicy(name=None)
一个Keras层的dtype策略.
一个dtype策略决定了层的计算和变量dtype.每个层都有一个策略.策略可以通过层的dtype
参数传递给层构造函数,或者可以通过keras.config.set_dtype_policy
设置全局策略.
参数:
name: 策略名称,它决定了计算和变量dtype.可以是任何dtype名称,例如"float32"
或"float64"
,这将导致计算和变量dtype都是该dtype.也可以是字符串"mixed_float16"
或"mixed_bfloat16"
,这将导致计算dtype为float16
或bfloat16
,而变量dtype为float32
.
通常,只有在使用混合精度时才需要与dtype策略进行交互,混合精度是指使用float16或bfloat16进行计算,使用float32进行变量.这就是为什么术语mixed_precision
出现在API名称中.可以通过将"mixed_float16"
或"mixed_bfloat16"
传递给keras.mixed_precision.set_dtype_policy()
来启用混合精度.
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1将自动使用混合精度
<DTypePolicy "mixed_float16">
>>> # 可以选择性地将层覆盖为使用float32而不是混合精度.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # 将策略设置回初始的float32.
>>> keras.config.set_dtype_policy('float32')
在上面的例子中,将dtype="float32"
传递给层等同于传递dtype=keras.config.DTypePolicy("float32")
.一般来说,将一个dtype策略名称传递给层等同于传递相应的策略,因此不需要显式构造一个DTypePolicy
对象.
dtype_policy
functionkeras.mixed_precision.dtype_policy()
返回当前默认的数据类型策略对象.
set_dtype_policy
functionkeras.mixed_precision.set_dtype_policy(policy)
设置默认的dtype策略全局生效.
示例:
>>> keras.config.set_dtype_policy("mixed_float16")