jax.lax.stop_梯度

jax.lax.stop_梯度#

jax.lax.stop_gradient(x)[源代码][源代码]#

停止梯度计算。

在操作上,stop_gradient 是一个恒等函数,即它返回参数 x 不变。然而,stop_gradient 在正向或反向模式自动微分时阻止梯度的流动。如果有多个嵌套的梯度计算,stop_gradient 会阻止所有这些计算的梯度。

例如:

>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
参数:

x (T)

返回类型:

T