jax.experimental.sparse.BCOO#
- class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)[源代码][源代码]#
在 JAX 中实现的实验性批处理 COO 矩阵
- 参数:
- data#
形状为
[*batch_dims, nse, *dense_dims]
的 ndarray,包含稀疏矩阵中显式存储的数据。- 类型:
数组
- indices#
形状为
[*batch_dims, nse, n_sparse]
的 ndarray,包含显式存储数据的索引。重复条目将被求和。- 类型:
数组
示例
从密集数组创建稀疏数组:
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]]) >>> M_sp = BCOO.fromdense(M) >>> M_sp BCOO(float32[2, 3], nse=3)
检查内部表示:
>>> M_sp.data Array([2., 1., 4.], dtype=float32) >>> M_sp.indices Array([[0, 1], [1, 0], [1, 2]], dtype=int32)
从稀疏数组创建密集数组:
>>> M_sp.todense() Array([[0., 2., 0.], [1., 0., 4.]], dtype=float32)
从COO数据和索引创建稀疏数组:
>>> data = jnp.array([1., 3., 5.]) >>> indices = jnp.array([[0, 0], ... [1, 1], ... [2, 2]]) >>> mat = BCOO((data, indices), shape=(3, 3)) >>> mat BCOO(float32[3, 3], nse=3) >>> mat.todense() Array([[1., 0., 0.], [0., 3., 0.], [0., 0., 5.]], dtype=float32)
方法
__init__
(args, *, shape[, indices_sorted, ...])astype
(*args, **kwargs)复制数组并转换为指定的数据类型。
block_until_ready
()from_scipy_sparse
(mat, *[, index_dtype, ...])从
scipy.sparse
数组创建一个 BCOO 数组。fromdense
(mat, *[, nse, index_dtype, ...])从密集的
Array
创建一个 BCOO 数组。reshape
(*args, **kwargs)返回一个包含相同数据但形状不同的新数组。
sort_indices
()返回按索引排序的矩阵副本。
sum
(*args, **kwargs)沿轴求和数组。
sum_duplicates
([nse, remove_zeros])返回一个数组副本,其中重复的索引值被求和。
todense
()创建数组的密集版本。
transpose
([axes])创建一个包含转置的新数组。
tree_flatten
()tree_unflatten
(aux_data, children)update_layout
(*[, n_batch, n_dense, ...])更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。
属性