jax.experimental.sparse.value_and_grad

jax.experimental.sparse.value_and_grad#

jax.experimental.sparse.value_and_grad(fun, argnums=0, has_aux=False, **kwargs)[源代码][源代码]#

稀疏感知版本的 jax.value_and_grad()

参数和返回值与 jax.value_and_grad() 相同,但在相对于 jax.experimental.sparse 数组取梯度时,梯度是在由数组的稀疏模式定义的子空间中计算的。

示例

>>> from jax.experimental import sparse
>>> X = sparse.BCOO.fromdense(jnp.arange(6.))
>>> y = jnp.ones(6)
>>> sparse.value_and_grad(lambda X, y: X @ y)(X, y)
(Array(15., dtype=float32), BCOO(float32[6], nse=5))
参数:
  • fun (Callable)

  • argnums (int | Sequence[int])

返回类型:

Callable[…, tuple[Any, Any]]