jax.dtypes 模块

jax.dtypes 模块#

bfloat16

bfloat16 浮点数值

canonicalize_dtype(dtype[, allow_extended_dtype])

基于 config.x64_enabled 将数据类型转换为规范数据类型。

float0

与标量类型和同名dtype相对应的DType类。

issubdtype(a, b)

如果第一个参数在类型层次结构中低于或等于第二个参数,则返回 True。

prng_key()

PRNG 键数据类型的标量类。

result_type(*args[, return_weak_type_flag])

方便的函数,用于应用 JAX 参数的类型提升。

scalar_type_of(x)

返回与 JAX 值关联的标量类型。