jax.nn.one_hot

目录

jax.nn.one_hot#

jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[源代码][源代码]#

对给定的索引进行独热编码。

输入 x 中的每个索引都被编码为一个长度为 num_classes 的零向量,其中在 index 位置的元素设置为1:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

超出范围 [0, num_classes) 的索引将被编码为零:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
参数:
  • x (Any) – 一个索引张量。

  • num_classes (int) – 独热维度中的类别数量。

  • dtype (Any) – optional, a float dtype for the returned values (default jnp.float_).

  • axis (int | AxisName) – 函数应计算的轴或轴。

返回类型:

Array