jax.nn 模块

jax.nn 模块#

神经网络库的常用功能。

激活函数#

relu

修正线性单元激活函数。

relu6

校正线性单元6激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

稀疏加函数。

sparse_sigmoid(x)

稀疏的Sigmoid激活函数。

soft_sign(x)

软符号激活函数。

silu(x)

SiLU(又称 swish)激活函数。

swish(x)

SiLU(又称 swish)激活函数。

log_sigmoid(x)

Log-sigmoid 激活函数。

leaky_relu(x[, negative_slope])

泄漏整流线性单元激活函数。

hard_sigmoid(x)

硬Sigmoid激活函数。

hard_silu(x)

硬 SiLU(swish)激活函数

hard_swish(x)

硬 SiLU(swish)激活函数

hard_tanh(x)

\(\mathrm{tanh}\) 激活函数。

elu(x[, alpha])

指数线性单元激活函数。

celu(x[, alpha])

连续可微的指数线性单元激活函数。

selu(x)

缩放指数线性单元激活。

gelu(x[, approximate])

高斯误差线性单元激活函数。

glu(x[, axis])

门控线性单元激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

其他功能#

softmax(x[, axis, where, initial])

Softmax 函数。

log_softmax(x[, axis, where, initial])

Log-Softmax 函数。

logsumexp()

对数-和-指数缩减。

standardize(x[, axis, mean, variance, ...])

通过减去 mean 并除以 \(\sqrt{\mathrm{variance}}\) 来标准化数组。

one_hot(x, num_classes, *[, dtype, axis])

对给定的索引进行独热编码。

dot_product_attention(query, key, value[, ...])

缩放点积注意力函数。