jax.export.DisabledSafetyCheck

jax.export.DisabledSafetyCheck#

class jax.export.DisabledSafetyCheck(_impl)[源代码][源代码]#

在(反)序列化时应跳过的安全检查。

这些检查大部分在序列化时执行,但有些会延迟到反序列化时进行。禁用的检查列表会附加到序列化中,例如,作为 jax.export.Exportedtf.XlaCallModuleOp 的字符串属性序列。

在使用 jax2tf 时,你可以通过传递 TF_XLA_FLAGS=–tf_xla_call_module_disabled_checks=platform 来禁用更多的反序列化安全检查。

参数:

_impl (str)

__init__(_impl)[源代码][源代码]#
参数:

_impl (str)

方法

__init__(_impl)

custom_call(target_name)

允许序列化一个未知是否稳定的调用目标。

is_custom_call()

返回此指令允许的自定义调用目标。

platform()

允许编译平台与导出平台不同。

shape_assertions()

已弃用:无操作。