jax.extend.linear_util.transformation_with_aux#
- jax.extend.linear_util.transformation_with_aux = functools.partial(<class 'functools.partial'>, <function transformation_with_aux>)[源代码][源代码]#
向 WrappedFun 添加一个带有辅助输出的额外变换。
- 参数:
fun (WrappedFun)
use_eq_store (bool)
- 返回类型:
tuple[WrappedFun, Callable[[], Any]]