jax.numpy.linspace#
- jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: Literal[False] = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) Array [源代码][源代码]#
- jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: bool, retstep: Literal[True], dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) tuple[Array, Array]
- jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, *, retstep: Literal[True], dtype: DTypeLike | None = None, axis: int = 0, device: xc.Device | Sharding | None = None) tuple[Array, Array]
- jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) Array | tuple[Array, Array]
返回区间内的等间隔数字。
JAX 实现的
numpy.linspace()
。- 参数:
- 返回:
values
是一个从start
到stop
均匀间隔的值数组step
是相邻值之间的间隔。
- 返回类型:
一个数组
values
,或者如果retstep
为 True,则为一个元组(values, step)
,其中
参见
jax.numpy.arange()
: 生成N
个均匀间隔的值,给定起始点和步长jax.numpy.logspace()
: 生成对数间隔的值。jax.numpy.geomspace()
: 生成几何间隔的值。
示例
0 到 10 之间的 5 个值的列表:
>>> jnp.linspace(0, 10, 5) Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
在0到10之间,不包括端点的8个值的列表:
>>> jnp.linspace(0, 10, 8, endpoint=False) Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32)
值列表及其之间的步长
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True) >>> vals Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) >>> step Array(1.25, dtype=float32)
多维 linspace:
>>> start = jnp.array([0, 5]) >>> stop = jnp.array([5, 10]) >>> jnp.linspace(start, stop, 5) Array([[ 0. , 5. ], [ 1.25, 6.25], [ 2.5 , 7.5 ], [ 3.75, 8.75], [ 5. , 10. ]], dtype=float32)