jax.lax.自定义线性求解

jax.lax.自定义线性求解#

jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)[源代码][源代码]#

使用隐式定义的梯度执行无矩阵线性求解。

此函数允许通过在解处直接进行隐式微分来覆盖或定义线性求解的梯度,而不是通过求解操作进行微分。这有时可以更快或更数值稳定,或者求解操作的微分可能根本未实现(例如,如果 solve 使用 lax.while_loop)。

必需的不变量:

x = solve(matvec, b)  # solve the linear equation
assert matvec(x) == b  # not checked
参数:
  • matvec – 线性函数以进行反转。必须是可微分的。

  • b – 方程的常数右侧。可以是数组的任何嵌套结构。

  • solve – 解决线性方程的高级函数,即对于所有与 b 形式相同的 x,满足 solve(matvec, x) == x。此函数不需要可微。

  • transpose_solve – 用于求解转置线性方程的高级函数,即 transpose_solve(vecmat, x) == x,其中 vecmat 是线性映射 matvec 的转置(通过自动微分自动计算)。除非 symmetric=True,否则这是反向模式自动微分所必需的,在这种情况下,solve 提供默认值。

  • symmetric – 布尔值,指示是否可以安全地假设线性映射对应于一个对称矩阵,即 matvec == vecmat

  • has_aux – 布尔值,指示 solvetranspose_solve 函数是否将辅助数据(如求解器诊断信息)作为第二个参数返回。

返回:

solve(matvec, b) 的结果,假设解 x 满足线性方程 matvec(x) == b 的情况下定义了梯度。