jax.experimental.key_reuse 模块

jax.experimental.key_reuse 模块#

实验性密钥重用检查#

此模块包含用于检测 JAX 程序中随机键重用的 实验性 功能。它正在积极开发中,这里的 API 可能会发生变化。以下用法需要 JAX 版本 0.4.26 或更新版本。

可以通过 jax_debug_key_reuse 配置启用密钥重用检查。这可以通过以下方式全局设置:

>>> jax.config.update('jax_debug_key_reuse', True)  

或者可以通过 jax.debug_key_reuse() 上下文管理器在本地启用。启用后,使用相同的键两次将导致 KeyReuseError:

>>> import jax
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   val1 = jax.random.normal(key)
...   val2 = jax.random.normal(key)  
Traceback (most recent call last):
 ...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

密钥重用检查器目前是实验性的,但在未来我们可能会默认启用它。