jax.random.choice#
- jax.random.choice(key, a, shape=(), replace=True, p=None, axis=0)[源代码][源代码]#
从给定的数组中生成一个随机样本。
警告
如果
p
的非零元素数量少于shape
中指定的请求样本数量,并且replace=False
,则此函数的输出未定义。请确保使用适当的输入。- 参数:
key (KeyArrayLike) – 一个用作随机密钥的 PRNG 密钥。
a (int | ArrayLike) – 数组或整数。如果是一个ndarray,则从其元素中生成一个随机样本。如果是一个整数,则生成的随机样本类似于从arange(a)中生成。
shape (Shape) – tuple of ints, 可选. 输出形状。如果给定的形状是, 例如
(m, n)
, 那么会抽取m * n
个样本。默认是 (), 在这种情况下会返回一个单一值。replace (bool) – boolean. 样本是有放回还是无放回。默认为 True。
p (RealArray | None) – 1-D 类数组, 与 a 中每个条目相关的概率。如果没有给出,样本假设 a 中所有条目的均匀分布。
axis (int) – int, 可选。执行选择的轴。默认值为 0,按行选择。
- 返回:
一个形状为 shape 的数组,包含来自 a 的样本。
- 返回类型: