jax.experimental.checkify.check_错误

目录

jax.experimental.checkify.check_错误#

jax.experimental.checkify.check_error(error)[源代码][源代码]#

如果 error 表示失败,则引发异常。由 checkify() 功能化。

此函数的语义等同于:

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError

但与该实现不同,check_error 可以使用 checkify() 变换进行功能化。

此函数类似于 check() ,但具有不同的签名:虽然 check() 接受一个布尔谓词和一个新的错误消息字符串作为参数,但此函数接受一个 Error 值作为参数。两者 check() 和此函数在失败时都会引发 Python 异常(副作用),因此不能被 jit()pmap()scan() 等函数分阶段执行。两者也可以通过使用 checkify() 进行功能化。

但与 check() 不同,这个函数类似于 checkify() 的直接反向操作:虽然 checkify() 接受一个可能引发 Python 异常的函数作为输入,并生成一个不会产生该效果但会输出 Error 值的新函数,但这个 check_error 函数可以接受 Error 值作为输入,并可能产生引发异常的副作用。也就是说,当 checkify() 从可函数化的异常效果转向错误值时,这个 check_error 则从错误值转向可函数化的异常效果。

check_error 在你想将由 Error 值表示的检查(通过 checkify() 功能化 checks 生成)转换回 Python 异常时非常有用。

参数:

error (Error) – 要检查的错误。

返回类型:

None

例如,你可能希望通过 checkify 功能化程序的一部分,通过 jit() 阶段输出你的功能化代码,然后在 jit() 外部重新注入你的错误值:

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)  
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)