jax.experimental.checkify.checkify#
- jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[源代码][源代码]#
在 fun 中功能化 check 调用,并可选择添加运行时错误检查。
运行时错误要么是用户添加的
check()断言,要么是根据errors参数自动添加的检查,如 NaN 检查。返回的函数将返回一个错误对象 err 以及原始函数的输出。
err.get()将返回None``(如果没有发生错误)或包含错误信息的字符串。此错误信息将对应于发生的第一个错误。``err.throw()如果发生错误,将引发带有错误信息的 ValueError。默认情况下,仅启用用户添加的
check()断言。您可以通过errors参数启用自动检查。- 可以启用的自动检查集,以及何时生成错误:
user_checks: 一个check()评估为 False。nan_checks: 浮点运算生成了一个 NaN 值作为输出。div_checks: 除以零。index_checks: 索引超出边界。
可以通过传递一个错误 集合`(例如 ``errors=nan_checks`)来同时启用多个类别。多个集合可以重新组合(例如
errors=float_checks|user_checks)- 参数:
- 返回:
一个接受与
fun相同参数的函数,并返回一个包含两个元素的输出:第一个元素是一个Error值,表示第一个失败的check(),第二个元素是fun的原始输出。- 返回类型:
例如:
>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental import checkify >>> >>> @jax.jit ... def f(x): ... y = jnp.sin(x) ... return x+y >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) >>> err.throw() Traceback (most recent call last): ... jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin