jax.experimental.sparse.grad#
- jax.experimental.sparse.grad(fun, argnums=0, has_aux=False, **kwargs)[源代码][源代码]#
稀疏感知版本的
jax.grad()
参数和返回值与
jax.grad()
相同,但在相对于jax.experimental.sparse
数组求梯度时,梯度是在由数组的稀疏模式定义的子空间中计算的。示例
>>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) >>> y = jnp.ones(6) >>> sparse.grad(lambda X, y: X @ y)(X, y) BCOO(float32[6], nse=5)