SeedGenerator
classkeras.random.SeedGenerator(seed=None, name=None, **kwargs)
生成每次调用使用RNG的函数时变化的变量种子.
在Keras中,所有使用RNG的方法(如keras.random.normal()
)都是无状态的,这意味着如果你传递一个整数种子给他们(如seed=42
),他们每次调用都会返回相同的值.为了在每次调用时获得不同的值,你必须使用SeedGenerator
作为种子参数.SeedGenerator
对象是有状态的.
示例:
seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)
在层中的使用:
class Dropout(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, x, training=False):
if training:
return keras.random.dropout(
x, rate=0.5, seed=self.seed_generator
)
return x