jax.lax.stop_梯度#
- jax.lax.stop_gradient(x)[源代码][源代码]#
停止梯度计算。
在操作上,
stop_gradient
是一个恒等函数,即它返回参数 x 不变。然而,stop_gradient
在正向或反向模式自动微分时阻止梯度的流动。如果有多个嵌套的梯度计算,stop_gradient
会阻止所有这些计算的梯度。例如:
>>> jax.grad(lambda x: x**2)(3.) Array(6., dtype=float32, weak_type=True) >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.) Array(0., dtype=float32, weak_type=True) >>> jax.grad(jax.grad(lambda x: x**2))(3.) Array(2., dtype=float32, weak_type=True) >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.) Array(0., dtype=float32, weak_type=True)
- 参数:
x (T)
- 返回类型:
T