jax.random.categorical

目录

jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None)[源代码][源代码]#

从分类分布中采样随机值。

参数:
  • key (KeyArrayLike) – 一个用作随机密钥的 PRNG 密钥。

  • logits (RealArray) – 分类分布(s)的未归一化对数概率,以便 softmax(logits, axis) 给出相应的概率。

  • axis (int) – 沿此轴的logits属于同一类别分布。

  • shape (Shape | None) – 可选的,一个非负整数元组,表示结果的形状。必须与 np.delete(logits.shape, axis) 广播兼容。默认值(None)生成一个结果形状等于 np.delete(logits.shape, axis)

返回:

一个具有 int 数据类型和形状的随机数组,形状由 shape 给出,如果 shape 不是 None,否则为 np.delete(logits.shape, axis)

返回类型:

Array