jax.experimental.key_reuse
模块#
实验性密钥重用检查#
此模块包含用于检测 JAX 程序中随机键重用的 实验性 功能。它正在积极开发中,这里的 API 可能会发生变化。以下用法需要 JAX 版本 0.4.26 或更新版本。
可以通过 jax_debug_key_reuse
配置启用密钥重用检查。这可以通过以下方式全局设置:
>>> jax.config.update('jax_debug_key_reuse', True)
或者可以通过 jax.debug_key_reuse()
上下文管理器在本地启用。启用后,使用相同的键两次将导致 KeyReuseError
:
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key)
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
密钥重用检查器目前是实验性的,但在未来我们可能会默认启用它。