jax.extend.linear_util.transformation_with_aux

jax.extend.linear_util.transformation_with_aux#

jax.extend.linear_util.transformation_with_aux = functools.partial(<class 'functools.partial'>, <function transformation_with_aux>)[源代码][源代码]#

向 WrappedFun 添加一个带有辅助输出的额外变换。

参数:
返回类型:

tuple[WrappedFun, Callable[[], Any]]