更新日志

目录

更新日志#

最佳查看位置 这里。 关于实验性 Pallas API 的具体变更,请参阅 Pallas 更新日志

jax 0.4.33#

这是基于 jax 0.4.32 的补丁发布,修复了该版本中发现的两个错误。

在 JAX 0.4.32 版本中,发现了一个仅在 TPU 上出现的数据损坏错误,该错误仅在同一作业中存在多个 TPU 切片时表现出来,例如在多个 v5e 切片上进行训练时。此版本通过固定 libtpu 的修复版本来解决该问题。

jaxlib 0.4.33#

此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。

jax 0.4.32 (2024年9月11日)#

注意:由于TPU上的数据损坏错误,此版本已从PyPi中撤回。更多详情请参见0.4.33版本发布说明。

  • 新功能

  • 更改

    • jax_enable_memories 标志默认设置为 True

    • jax.numpy 现在支持 Python 数组 API 标准 2023.12 版本。更多信息请参见 Python 数组 API 标准

    • 在更多情况下,CPU 后端的计算现在可以异步分派。以前,非并行计算总是同步分派的。你可以通过设置 jax.config.update('jax_cpu_enable_async_dispatch', False) 来恢复旧的行为。

    • 新增了 jax.process_indices() 函数,以替代在 JAX v0.2.13 中被弃用的 jax.host_ids() 函数。

    • 为了与 numpy.fabs 的行为保持一致,jax.numpy.fabs 已被修改,不再支持 复数数据类型

    • jax.tree_util.register_dataclass 现在检查 data_fieldsmeta_fields 是否包含所有 init=True 的 dataclass 字段,并且仅包含这些字段,如果 nodetype 是一个 dataclass 的话。

    • 几个 jax.numpy 函数现在具有完整的 ufunc 接口,包括 addmultiplybitwise_andbitwise_orbitwise_xorlogical_andlogical_andlogical_and。在未来的版本中,我们计划将这些扩展到其他 ufuncs。

    • 添加了 jax.lax.optimization_barrier(),它允许用户防止编译器优化,如公共子表达式消除,并控制调度。

  • 重大变更

    • MHLO MLIR 方言 (jax.extend.mlir.mhlo) 已被移除。请改用 stablehlo 方言。

  • 弃用

    • 自 JAX v0.4.27 起被弃用后,不再允许向 jax.numpy.clip()jax.numpy.hypot() 输入复杂数据。

    • 已弃用以下API:

      • jax.lib.xla_bridge.xla_client: 直接使用 jax.lib.xla_client

      • jax.lib.xla_bridge.get_backend: 使用 jax.extend.backend.get_backend()

      • jax.lib.xla_bridge.default_backend: 使用 jax.extend.backend.default_backend()

    • jax.experimental.array_api 模块已被弃用,不再需要导入它来使用数组 API。jax.numpy 直接支持数组 API;更多信息请参见 Python 数组 API 标准

    • 内部工具 jax.core.check_eqnjax.core.check_typejax.core.check_valid_jaxtype 现已弃用,并将在未来移除。

    • jax.numpy.round_ 已被弃用,这是由于 NumPy 2.0 中相应 API 的移除。请改用 jax.numpy.round()

    • 传递一个 DLPack 胶囊到 jax.dlpack.from_dlpack() 已被弃用。jax.dlpack.from_dlpack() 的参数应为另一个实现 __dlpack__ 协议的框架中的数组。

jaxlib 0.4.32 (2024年9月11日)#

注意:由于TPU上的数据损坏错误,此版本已从PyPi中撤回。更多详情请参见0.4.33版本发布说明。

  • 重大变更

    • 添加了密封的 CUDA 支持。密封的 CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后在各种 Bazel 目标中使用 CUDA 库和工具作为依赖项。这使得 JAX 及其支持的 CUDA 版本的构建更具可重复性。

  • 更改

    • 添加了 SparseCore 分析。

      • JAX 现在支持在 TPUv5p 芯片上进行 SparseCore 分析。这些追踪记录可以在 Tensorboard Profiler 的 TraceViewer 中查看。

jax 0.4.31 (2024年7月29日)#

  • 删除

    • xmap 已被删除。请使用 shard_map() 作为替代。

  • 更改

    • 最低的 CuDNN 版本是 v9.1。这在之前的版本中也是如此,但我们现在正式声明这一版本约束。

    • 最低 Python 版本现在是 3.10。3.10 将作为最低支持版本保留至 2025 年 7 月。

    • 最低 NumPy 版本现在是 1.24。NumPy 1.24 将作为最低支持版本保留至 2024 年 12 月。

    • 最低的 SciPy 版本现在是 1.10。SciPy 1.10 将作为最低支持版本,直到 2025 年 1 月。

    • jax.numpy.ceil(), jax.numpy.floor()jax.numpy.trunc() 现在返回与输入相同数据类型的输出,即不再将整数或布尔输入向上转换为浮点数。

    • libdevice.10.bc 不再与 CUDA 轮子捆绑在一起。它必须作为本地 CUDA 安装的一部分安装,或者通过 NVIDIA 的 CUDA pip 轮子安装。

    • jax.experimental.pallas.BlockSpec 现在期望 block_shapeindex_map 之前 传递。旧的参数顺序已被弃用,并将在未来的版本中移除。

    • 更新了 GPU 设备的 repr,使其与 TPU/CPU 更加一致。例如,cuda(id=0) 现在将是 CudaDevice(id=0)

    • jax.Array 中添加了 device 属性和 to_device 方法,作为 JAX 对 Array API 支持的一部分。

  • 弃用

    • 移除了许多与多态形状相关的先前已弃用的内部API。从 jax.core 中移除了 canonicalize_shapedimension_as_valuedefinitely_equalsymbolic_equal_dim

    • HLO 降低规则不应再将单例 ir.Values 包装在元组中。相反,应返回未包装的单例 ir.Values。对包装值的支持将在 JAX 的未来版本中移除。

    • jax.experimental.jax2tf.convert()native_serialization=Falseenable_xla=False 的情况下已被弃用,此支持将在未来版本中移除。自 JAX 0.4.16(2023年9月)以来,原生序列化已成为默认设置。

    • 之前已弃用的函数 jax.random.shuffle 已被移除;请改用 jax.random.permutation 并设置 independent=True

jaxlib 0.4.31 (2024年7月29日)#

  • 错误修复

    • 修复了一个错误,该错误意味着 jit 调度快速路径未能正确处理负的 static_argnums。

    • 修复了一个错误,该错误导致对奇异矩阵批次的三角求解产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。

jax 0.4.30 (2024年6月18日)#

  • 更改

    • JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本被提升到 0.4.0,但在本次发布中已回滚,以便 TensorFlow 和 JAX 的用户有更多时间迁移到更新的 TensorFlow 版本。

    • jax.experimental.mesh_utils 现在可以为 TPU v5e 创建一个高效的网格。

    • jax 现在直接依赖于 jaxlib。这一变化是由 CUDA 插件切换实现的:不再有多个 jaxlib 变体。你可以通过 pip install jax 安装仅支持 CPU 的 jax,无需额外配置。

    • 添加了一个用于导出和序列化 JAX 函数的 API。这曾经存在于 jax.experimental.export(即将被弃用),现在将存在于 jax.export。请参阅 文档

  • 弃用

    • 内部美化打印工具 jax.core.pp_* 已被弃用,并将在未来的版本中移除。

    • 跟踪器的哈希处理已被弃用,并且将在未来的 JAX 版本中导致 TypeError。这之前是这种情况,但在最近的几个 JAX 版本中出现了意外的回归。

    • jax.experimental.export 已被弃用。请改用 jax.export。请参阅 迁移指南

    • 在大多数情况下,传递一个数组来代替 dtype 已被弃用;例如,对于数组 xyx.astype(y) 将引发警告。要静默该警告,请使用 x.astype(y.dtype)

    • jax.xla_computation 已被弃用,并将在未来的版本中移除。请使用 AOT API 来获得与 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以替换为 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 你也可以使用 jax.stages.Lowered.out_info 属性来获取输出信息(如树结构、形状和数据类型)。

      • 对于跨后端的降低,你可以将 jax.xla_computation(fn, backend='tpu')(*args, **kwargs) 替换为 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

jaxlib 0.4.30 (2024年6月18日)#

  • 单片 CUDA jaxlibs 的支持已被移除。您必须使用基于插件的安装(pip install jax[cuda12]pip install jax[cuda12_local])。

jax 0.4.29 (2024年6月10日)#

  • 更改

    • 我们预计这将是支持单体 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本。未来的版本将使用 CUDA 插件 jaxlib(例如 pip install jax[cuda12])。

    • JAX 现在需要 ml_dtypes 版本 0.4.0 或更新版本。

    • 移除了对旧版 jax.experimental.export API 的向后兼容支持。不再可能使用 from jax.experimental.export import export,而应改为使用 from jax.experimental import export。移除的功能自 0.4.24 版本起已被弃用。

    • jax.tree.all()jax.tree_util.tree_all() 添加了 is_leaf 参数。

  • 弃用

    • jax.sharding.XLACompatibleSharding 已被弃用。请使用 jax.sharding.Sharding

    • jax.experimental.Exported.in_shardings 已被重命名为 jax.experimental.Exported.in_shardings_hloout_shardings 也是如此。旧名称将在3个月后移除。

    • 移除了许多之前已弃用的API:

      • jax.core 导入:non_negative_dim, DimSize, Shape

      • 来自 jax.laxtie_in

      • 来自 jax.nnnormalize

      • jax.interpreters.xla 模块中:backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXlaOp

    • tol 参数在 jax.numpy.linalg.matrix_rank() 中已被弃用,并即将被移除。请改用 rtol

    • rcond 参数在 jax.numpy.linalg.pinv() 中已被弃用,并即将被移除。请改用 rtol

    • 已弃用的 jax.config 子模块已被移除。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • jax.random API 不再接受批量键,之前有些是不经意间接受的。今后,我们建议在这种情况下显式使用 jax.vmap()

    • jax.scipy.special.beta() 中,xy 参数已重命名为 ab,以与其他 beta API 保持一致。

  • 新功能

    • 添加了 jax.experimental.Exported.in_shardings_jax() 以从存储在 Exported 对象中的 HloShardings 构建可以与 JAX API 一起使用的 shardings。

jaxlib 0.4.29 (2024年6月10日)#

  • 错误修复

    • 修复了一个XLA在某些连接操作中分片不正确的错误,这表现为累积减少操作的输出不正确(#21403)。

    • 修复了一个XLA:CPU错误编译某些矩阵乘法融合的bug(https://github.com/openxla/xla/pull/13301)。

    • 修复了GPU上的编译器崩溃问题(https://github.com/google/jax/issues/21396)。

  • 弃用

    • jax.tree.map(f, None, non-None) 现在会发出一个 DeprecationWarning,并且在未来的 jax 版本中将会引发错误。None 仅是其自身的树前缀。要保留当前行为,您可以通过编写以下代码,让 jax.tree.mapNone 视为叶值:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

jax 0.4.28 (2024年5月9日)#

  • 错误修复

    • 恢复了对 make_jaxpr 的更改,该更改破坏了 Equinox (#21116)。

  • 弃用与移除

  • 更改

    • 此版本的最低 jaxlib 版本为 0.4.27。

jaxlib 0.4.28 (2024年5月9日)#

  • 错误修复

    • 修复了Python 3.10或更早版本中Array和JIT Python对象类型名称中的内存损坏错误。

    • 修复了在CUDA 12.4下出现的警告 '+ptx84' 不是此目标的可识别特性

    • 修复了CPU上的编译缓慢问题。

  • 更改

    • Windows 构建现在使用 Clang 而不是 MSVC 进行构建。

jax 0.4.27 (2024年5月7日)#

  • 新功能

    • 添加了 jax.numpy.unstack()jax.numpy.cumulative_sum(),这些函数是根据即将被NumPy采用的数组API 2023标准添加的。

    • 添加了一个新的配置选项 jax_cpu_collectives_implementation 以选择 CPU 后端使用的跨进程集体操作的实现。可选的选项有 'none'(默认)、'gloo''mpi'(需要 jaxlib 0.4.26)。如果设置为 'none',则跨进程集体操作被禁用。

  • 更改

    • jax.pure_callback(), jax.experimental.io_callback()jax.debug.callback() 现在使用 jax.Array 而不是 np.ndarray。你可以通过在传递给回调之前使用 jax.tree.map(np.asarray, args) 转换参数来恢复旧的行为。

    • complex_arr.astype(bool) 现在遵循与 NumPy 相同的语义,当 complex_arr 等于 0 + 0j 时返回 False,否则返回 True。

    • core.Token 现在是一个非平凡的类,它包装了一个 jax.Array。它可以在计算中创建并传递,以建立依赖关系。单例对象 core.token 已被移除,用户现在应该创建并使用新的 core.Token 对象。

    • 在GPU上,Threefry PRNG 实现默认不再降低为内核调用。这一选择可以在编译时成本下提高运行时内存使用。之前产生内核调用的行为可以通过 jax.config.update('jax_threefry_gpu_kernel_lowering', True) 恢复。如果新的默认设置导致问题,请提交错误报告。否则,我们打算在未来版本中移除此标志。

  • 弃用与移除

    • Pallas 现在专门使用 XLA 来编译 GPU 上的内核。通过 Triton Python API 的旧降级过程已被移除,JAX_TRITON_COMPILE_VIA_XLA 环境变量不再有任何影响。

    • jax.numpy.clip() 有一个新的参数签名:aa_mina_max 已被弃用,取而代之的是 x(仅位置参数)、minmax#20550)。

    • JAX 数组的 device() 方法已被移除,自 JAX v0.4.21 起已被弃用。请改用 arr.devices()

    • initial 参数在 jax.nn.softmax()jax.nn.log_softmax() 中已被弃用;现在无需设置此参数即可支持空输入。

    • jax.jit() 中,传递无效的 static_argnumsstatic_argnames 现在会导致错误而不是警告。

    • 最低的 jaxlib 版本现在是 0.4.23。

    • 当向 jax.numpy.hypot() 函数传递复数值输入时,现在会发出一个弃用警告。当弃用完成时,这将引发一个错误。

    • 传递给 jax.numpy.nonzero()jax.numpy.where() 及相关函数的标量参数现在会引发错误,这与 NumPy 中的类似更改一致。

    • 配置选项 jax_cpu_enable_gloo_collectives 已被弃用。请改用 jax.config.update('jax_cpu_collectives_implementation', 'gloo')

    • jax.Array.device_bufferjax.Array.device_buffers 方法在 JAX v0.4.22 中被弃用后已被移除。请改用 jax.Array.addressable_shardsjax.Array.addressable_data()

    • jax.numpy.whereconditionxy 参数现在仅支持位置参数,遵循 JAX v0.4.21 中对关键词参数的弃用。

    • jax.lax.linalg 中的函数,非数组参数现在必须通过关键字指定。以前,这会引发一个 DeprecationWarning。

    • 在多个 :func:jax.numpy API 中,现在需要类似数组的参数,包括 apply_along_axis()apply_over_axes()inner()outer()cross()kron()lexsort()

  • 错误修复

    • jax.numpy.astype() 现在当 copy=True 时总是返回一个副本。以前,当输出数组与输入数组的 dtype 相同时,不会进行复制。这可能会导致一些内存使用量的增加。默认值设置为 copy=False 以保持向后兼容性。

jaxlib 0.4.27 (2024年5月7日)#

jax 0.4.26 (2024年4月3日)#

  • 新功能

  • 更改

    • 复数值的 jax.numpy.geomspace() 现在选择与 NumPy 2.0 一致的对数螺旋分支。

    • jax.vmap 下,lax.rng_bit_generator 的行为,以及 'rbg''unsafe_rbg' PRNG 实现的行为 已更改,使得映射键只会从批次中的第一个键生成随机数。

    • 文档现在使用 jax.random.key 来构建 PRNG 键数组,而不是 jax.random.PRNGKey

  • 弃用与移除

    • jax.tree_map() 已弃用;请改用 jax.tree.map,或者为了与旧版 JAX 兼容,请使用 jax.tree_util.tree_map()

    • jax.clear_backends() 已被弃用,因为它不一定能如其名所示地执行,并且可能导致意外后果,例如,它不会销毁现有的后端和释放相应的资源。如果你只想清理编译缓存,请使用 jax.clear_caches()。为了向后兼容或你确实需要切换/重新初始化默认后端,请使用 jax.extend.backend.clear_backends()

    • jax.experimental.maps 模块和 jax.experimental.maps.xmap 已被弃用。请使用 jax.experimental.shard_map 或带有 spmd_axis_name 参数的 jax.vmap 来表达 SPMD 设备并行计算。

    • jax.experimental.host_callback 模块已被弃用。请改用 新的 JAX 外部回调。添加了 JAX_HOST_CALLBACK_LEGACY 标志以协助向新回调的过渡。有关讨论,请参见 #20385

    • 传递无法转换为 JAX 数组的参数给 jax.numpy.array_equal()jax.numpy.array_equiv() 现在会导致异常。

    • 已弃用的标志 jax_parallel_functions_output_gda 已被移除。此标志早已被弃用,没有任何作用;它的使用是无操作的。

    • 之前已弃用的导入 jax.interpreters.ad.configjax.interpreters.ad.source_info_util 现在已被移除。请改用 jax.configjax.extend.source_info_util

    • JAX 导出不再支持旧的序列化版本。自2023年10月27日起,版本9已被支持,并自2024年2月1日起成为默认版本。请参阅版本描述。此更改可能会破坏设置低于9的特定JAX序列化版本的客户端。

jaxlib 0.4.26 (2024年4月3日)#

  • 更改

    • JAX 现在仅支持 CUDA 12.1 或更新版本。已放弃对 CUDA 11.8 的支持。

    • JAX 现在支持 NumPy 2.0。

jax 0.4.25 (2024年2月26日)#

  • 新功能

  • 更改

    • Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。你可以通过将 JAX_TRITON_COMPILE_VIA_XLA 环境变量设置为 "0" 来恢复旧的行为。

    • jax.interpreters.xla 中,一些在 v0.4.24 中被移除的已弃用 API 在 v0.4.25 中被重新添加,包括 backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXLAOp。这些仍然被视为已弃用,并且当有更好的替代品可用时,它们将在未来再次被移除。有关讨论,请参阅 #19816

  • 弃用与移除

    • jax.numpy.linalg.solve() 现在对 b.ndim > 1 的批量一维求解显示弃用警告。未来这些将被视为批量二维求解。

    • 现在,将非标量数组转换为 Python 标量会引发错误,无论数组的大小如何。以前,在非标量数组大小为 1 的情况下会引发弃用警告。这与 NumPy 中的类似弃用一致。

    • 之前已弃用的配置API已按照标准的3个月弃用周期被移除(参见 api-兼容性)。这些包括

      • jax.config.config 对象和

      • define_*_stateDEFINE_* 方法是 jax.config 的。

    • 通过 import jax.config 导入 jax.config 子模块已被弃用。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • 最低 jaxlib 版本现在是 0.4.20。

jaxlib 0.4.25 (2024年2月26日)#

jax 0.4.24 (2024年2月6日)#

  • 更改

    • JAX 降低到 StableHLO 不再依赖于物理设备。如果你的原语在降低规则中包装了 custom_partitioning 或 JAX 回调,即传递给 mlir.register_loweringrule 参数的函数,那么请将你的原语添加到 jax._src.dispatch.prim_requires_devices_during_lowering 集合中。这是必要的,因为 custom_partitioning 和 JAX 回调在降低过程中需要物理设备来创建 Sharding。这是一个临时状态,直到我们可以在没有物理设备的情况下创建 Sharding

    • jax.numpy.argsort()jax.numpy.sort() 现在支持 stabledescending 参数。

    • 对形状多态性处理的几项更改(用于 jax.experimental.jax2tfjax.experimental.export):

      • 符号表达式的更整洁美观的打印(#19227

      • 增加了对维度变量指定符号约束的能力。这使得形状多态性更具表现力,并为解决不等式推理中的限制提供了一种方法。参见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 通过添加符号约束(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名字。来自不同作用域的符号表达式不能相互作用,例如在算术运算中。作用域由 jax.experimental.jax2tf.convert()jax.experimental.export.symbolic_shape()jax.experimental.export.symbolic_args_specs() 引入。符号表达式 e 的作用域可以通过 e.scope 读取,并传递给上述函数,以指导它们在给定作用域中构建符号表达式。参见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 简化和更快的相等比较,我们考虑两个符号维度相等,如果它们差值的规范化形式减少到0(#19231;注意这可能导致用户可见的行为变化)

      • 改进了不明确的不等式比较的错误信息 (#19235)。

      • core.non_negative_dim API(最近引入)已被弃用,并引入了 core.max_dimcore.min_dim#18953)来表示符号维度的 maxmin。你可以使用 core.max_dim(d, 0) 来替代 core.non_negative_dim(d)

      • shape_poly.is_poly_dim 已被弃用,取而代之的是 export.is_symbolic_dim (#19282)。

      • export.args_specs 已被弃用,取而代之的是 export.symbolic_args_specs ({jax-issue}#19283`)。

      • shape_poly.PolyShapejax2tf.PolyShape 已被弃用,请使用字符串来指定多态形状(#19284)。

      • JAX 默认的本地序列化版本现在是 9。这对于 jax.experimental.jax2tfjax.experimental.export 是相关的。请参阅 版本号描述

    • 重构了 jax.experimental.export 的 API。现在应该使用 from jax.experimental import export 而不是 from jax.experimental.export import export。旧的导入方式将在接下来的 3 个月内继续有效,作为弃用期。

    • 添加了 jax.scipy.stats.sem()

    • jax.numpy.unique()return_inverse = True 时返回重塑为输入维度大小的逆索引,这与 NumPy 2.0 中 numpy.unique() 的类似更改一致。

    • jax.numpy.sign() 现在对非零复数输入返回 x / abs(x)。这与 numpy.sign() 在 NumPy 版本 2.0 中的行为一致。

    • jax.scipy.special.logsumexp()return_sign=True 时,现在使用 NumPy 2.0 的复数符号约定,即 x / abs(x)。这与 SciPy v1.13 中 scipy.special.logsumexp() 的行为一致。

    • JAX 现在支持 bool DLPack 类型用于导入和导出。之前 bool 值无法导入,并且导出时会被转换为整数。

  • 弃用与移除

    • 一些之前已弃用的函数已被移除,遵循标准的3个月弃用周期(参见 API 兼容性)。这包括:

      • jax.core 中:TracerArrayConversionError, TracerIntegerConversionError, UnexpectedTracerError, as_hashable_function, collections, dtypes, lu, map, namedtuple, partial, pp, ref, safe_zip, safe_map, source_info_util, total_ordering, traceback_util, tuple_delete, tuple_insert, 和 zip

      • 来自 jax.laxdtypesitertoolsnaryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule

      • jax.linear_util 子模块及其所有内容。

      • jax.prng 子模块及其所有内容。

      • 来自 jax.randomPRNGKeyArrayKeyArraydefault_prng_implthreefry_2x32threefry2x32_keythreefry2x32_prbg_keyunsafe_rbg_key

      • 来自 jax.tree_utilregister_keypathsAttributeKeyPathEntryGetItemKeyPathEntry

      • jax.interpreters.xla 中:backend_specific_translations, translations, register_translation, xla_destructure, TranslationRule, TranslationContext, axis_groups, ShapedArray, ConcreteArray, AxisEnv, backend_compile, 和 XLAOp

      • 来自 jax.numpy 的:NINF, NZERO, PZERO, row_stack, issubsctype, trapz, 和 in1d

      • 来自 jax.scipy.linalgtriltriu

    • 之前已弃用的方法 PRNGKeyArray.unsafe_raw_array 已被移除。请改用 jax.random.key_data()

    • bool(empty_array) 现在会引发错误,而不是返回 False。此前会引发一个弃用警告,这与 NumPy 中的类似更改一致。

    • 对 mhlo MLIR 方言的支持已被弃用。JAX 不再使用 mhlo 方言,转而使用 stablehlo。未来将移除引用“mhlo”的 API。请改用“stablehlo”方言。

    • jax.random: 直接将批量键传递给随机数生成函数,如 bits()gamma() 等,已被弃用,并将发出 FutureWarning。请使用 jax.vmap 进行显式批处理。

    • jax.lax.tie_in() 已被弃用:自 JAX v0.2.0 以来,它一直是一个无操作指令。

jaxlib 0.4.24 (2024年2月6日)#

  • 更改

    • JAX 现在支持 CUDA 12.3 和 CUDA 11.8。对 CUDA 12.2 的支持已被移除。

    • cost_analysis 现在可以与交叉编译的 Compiled 对象一起工作(例如,当使用 .lower().compile() 与拓扑对象一起时,例如从非 TPU 计算机编译到 Cloud TPU)。

    • 添加了对 CUDA 数组接口 的导入支持(需要 jax 0.4.25)。

jax 0.4.23 (2023年12月13日)#

jaxlib 0.4.23 (2023年12月13日)#

  • 修复了在编译过程中导致GPU编译器产生冗长日志的错误。

jax 0.4.22 (2023年12月13日)#

  • 弃用

    • JAX 数组的 device_bufferdevice_buffers 属性已被弃用。显式缓冲区已被更灵活的数组分片接口所取代,但可以通过以下方式恢复之前的输出:

      • arr.device_buffer 变为 arr.addressable_data(0)

      • arr.device_buffers 变为 [x.data for x in arr.addressable_shards]

jaxlib 0.4.22 (2023年12月13日)#

jax 0.4.21 (2023年12月4日)#

  • 新功能

  • 更改

    • 最小的 jaxlib 版本现在是 0.4.19。

    • 发布的轮子现在使用 clang 而不是 gcc 构建。

    • 在调用 jax.distributed.initialize() 之前,确保设备后端未被初始化。

    • 在云TPU环境中自动传递参数给 jax.distributed.initialize()

  • 弃用

    • 之前已弃用的 sym_pos 参数已从 jax.scipy.linalg.solve() 中移除。请改用 assume_a='pos'

    • None 传递给 jax.array()jax.asarray(),无论是直接传递还是作为列表或元组的一部分,都已弃用,现在会引发 FutureWarning。目前它会被转换为 NaN,未来将会引发 TypeError

    • 通过关键字参数传递 conditionxy 参数给 jax.numpy.where 已被弃用,以匹配 numpy.where

    • 传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 的参数如果无法转换为 JAX 数组,已被弃用,现在会引发 DeprecationWaning。目前这些函数会返回 False,未来这将引发异常。

    • JAX 数组的 device() 方法已被弃用。根据上下文,它可能被以下之一替换:

      • jax.Array.devices() 返回数组使用的所有设备的集合。

      • jax.Array.sharding 给出了数组使用的分片配置。

jaxlib 0.4.21 (2023年12月4日)#

  • 更改

    • 在准备添加分布式CPU支持时,JAX现在将CPU设备与GPU和TPU设备同等对待,即:

      • jax.devices() 包括分布式作业中所有存在的设备,即使是那些不属于当前进程的设备。jax.local_devices() 仍然只包括当前进程本地的设备,所以如果 jax.devices() 的更改对你造成了影响,你很可能需要改用 jax.local_devices()

      • CPU 设备现在在分布式作业中接收一个全局唯一的 ID 号;以前 CPU 设备会接收一个进程本地的 ID 号。

      • 每个CPU设备的process_index现在将与同一进程内的任何GPU或TPU设备匹配;之前CPU设备的process_index始终为0。

    • 在 NVIDIA GPU 上,JAX 现在优先选择 Jacobi SVD 求解器来处理大小不超过 1024x1024 的矩阵。Jacobi 求解器似乎比非 Jacobi 版本更快。

  • 错误修复

    • 修复了当传递包含非有限值的数组到非对称特征分解时出现的错误/挂起问题(#18226)。现在包含非有限值的数组会产生充满NaN的输出数组。

jax 0.4.20 (2023年11月2日)#

jaxlib 0.4.20 (2023年11月2日)#

  • 错误修复

    • 修复了 E4M3 和 E5M2 float8 类型之间的类型混淆问题。

jax 0.4.19 (2023年10月19日)#

  • 新功能

    • 添加了 jax.typing.DTypeLike,它可以用于注释可转换为 JAX dtypes 的对象。

    • 添加了 jax.numpy.fill_diagonal

  • 更改

    • JAX 现在需要 SciPy 1.9 或更新版本。

  • 错误修复

    • 在多控制器分布式JAX程序中,只有进程0会写入持久的编译缓存条目。如果缓存放置在GCS等网络文件系统上,这将修复写入争用问题。

    • 在确定已安装的 cusolver 和 cufft 库版本是否至少与 JAX 构建时所针对的版本一样新时,不再考虑补丁版本。

jaxlib 0.4.19 (2023年10月19日)#

  • 更改

    • jaxlib 现在将始终优先选择通过 pip 安装的 NVIDIA CUDA 库(nvidia-… 包),如果它们已安装,则优先于任何其他 CUDA 安装,包括在 LD_LIBRARY_PATH 中命名的安装。如果这导致问题,并且意图是使用系统安装的 CUDA,解决方法是移除通过 pip 安装的 CUDA 库包。

jax 0.4.18 (2023年10月6日)#

jaxlib 0.4.18 (2023年10月6日)#

  • 更改

    • CUDA jaxlibs 现在依赖用户安装兼容的 NCCL 版本。如果使用推荐的 cuda12_pip 安装,NCCL 应该会自动安装。目前,需要 NCCL 2.16 或更新版本。

    • 我们现在提供Linux aarch64的轮子,无论是否支持NVIDIA GPU。

    • jax.Array.item() 现在支持可选的索引参数。

  • 弃用

    • jax.lax 中,一些内部工具和无意中导出的内容已被弃用,并将在未来的版本中移除。

      • jax.lax.dtypes: 请改用 jax.dtypes

      • jax.lax.itertools: 请使用 itertools 代替。

      • naryop, naryop_dtype_rule, standard_abstract_eval, standard_naryop, standard_primitive, standard_unop, unop, 和 unop_dtype_rule 是内部工具,现已弃用且没有替代品。

  • 错误修复

    • 修复了 Cloud TPU 回归问题,由于 smem 导致编译时内存不足。

jax 0.4.17 (2023年10月3日)#

  • 新功能

  • 弃用

    • 移除了已弃用的模块 jax.abstract_arrays 及其所有内容。

    • jax.random 中的命名键构造函数已被弃用。请改为传递 impl 参数给 jax.random.PRNGKey()jax.random.key()

      • random.threefry2x32_key(seed) 变为 random.PRNGKey(seed, impl='threefry2x32')

      • random.rbg_key(seed) 变为 random.PRNGKey(seed, impl='rbg')

      • random.unsafe_rbg_key(seed) 变为 random.PRNGKey(seed, impl='unsafe_rbg')

  • 更改:

    • CUDA: JAX 现在会验证它找到的 CUDA 库至少与 JAX 构建时使用的 CUDA 库一样新。如果找到较旧的库,JAX 会引发异常,因为这比神秘的失败和崩溃更可取。

    • 移除了“未找到 GPU/TPU”警告。改为在 Linux 上,如果检测到 NVIDIA GPU 或 Google TPU 但未使用,并且未指定 --jax_platforms,则发出警告。

    • jax.scipy.stats.mode() 现在如果在大小为0的轴上计算众数,将返回一个0的计数,这与SciPy 1.11中scipy.stats.mode的行为相匹配。

    • 大多数 jax.numpy 函数和属性现在都有完全定义的类型存根。以前,许多这些函数和属性在静态类型检查器如 mypypytype 中被视为 Any

jaxlib 0.4.17 (2023年10月3日)#

  • 更改:

    • Python 3.12 的 wheel 包在此版本中被添加。

    • CUDA 12 轮子现在需要 CUDA 12.2 或更新版本,以及 cuDNN 8.9.4 或更新版本。

  • 错误修复:

    • 修复了在初始化JAX CPU后端时来自ABSL的固定日志垃圾信息。

jax 0.4.16 (2023年9月18日)#

  • 更改

    • 添加了 jax.numpy.ufunc,以及 jax.numpy.frompyfunc(),它可以将任何标量值函数转换为类似于 numpy.ufunc() 的对象,具有 outer()reduce()accumulate()at()reduceat() 等方法(#17054)。

    • 添加了 jax.scipy.integrate.trapezoid()

    • 当不在 IPython 下运行时:当引发异常时,JAX 现在会从回溯中过滤掉其内部帧的全部内容。(不再显示之前出现的“未过滤的堆栈跟踪”。)这将产生更友好的回溯。请参见此处查看示例。此行为可以通过设置 JAX_TRACEBACK_FILTERING=remove_frames(用于两个独立的未过滤/过滤的回溯,这是旧的行为)或 JAX_TRACEBACK_FILTERING=off(用于一个未过滤的回溯)来更改。

    • jax2tf 默认序列化版本现在是 7,它引入了新的形状 安全断言

    • 传递给 jax.sharding.Mesh 的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices() 已经是可哈希的。

  • 重大变更:

    • jax2tf 现在默认使用原生序列化。详情和覆盖默认机制请参见 jax2tf 文档

    • 选项 --jax_coordination_service 已被移除。它现在总是 True

    • jax.jaxpr_util 已从公共 JAX 命名空间中移除。

    • JAX_USE_PJRT_C_API_ON_TPU 不再起作用(即它总是默认为真)。

    • 在2021年12月引入的向后兼容标志 --jax_host_callback_ad_transforms 已被移除。

  • 弃用:

    • 以下 jax.numpy API 已根据 NumPy NEP-52 被弃用:

      • jax.numpy.NINF 已被弃用。请改用 -jax.numpy.inf

      • jax.numpy.PZERO 已被弃用。请改用 0.0

      • jax.numpy.NZERO 已被弃用。请改用 -0.0

      • jax.numpy.issubsctype(x, t) 已被弃用。请使用 jax.numpy.issubdtype(x.dtype, t)

      • jax.numpy.row_stack 已被弃用。请改用 jax.numpy.vstack

      • jax.numpy.in1d 已被弃用。请改用 jax.numpy.isin

      • jax.numpy.trapz 已被弃用。请改用 jax.scipy.integrate.trapezoid

    • jax.scipy.linalg.triljax.scipy.linalg.triu 已弃用,跟随 SciPy 的步伐。请改用 jax.numpy.triljax.numpy.triu

    • jax.lax.prod 在 JAX v0.4.11 中被弃用后已被移除。请改用内置的 math.prod

    • 一些与定义自定义 JAX 原语的 HLO 降低规则相关的 jax.interpreters.xla 导出已被弃用。自定义原语应使用 jax.interpreters.mlir 中的 StableHLO 降低工具来定义。

    • 以下在三个月弃用期后已被移除的先前弃用函数:

      • jax.abstract_arrays.ShapedArray: 使用 jax.core.ShapedArray

      • jax.abstract_arrays.raise_to_shaped: 使用 jax.core.raise_to_shaped

      • jax.numpy.alltrue: 使用 jax.numpy.all

      • jax.numpy.sometrue: 使用 jax.numpy.any

      • jax.numpy.product: 使用 jax.numpy.prod

      • jax.numpy.cumproduct: 使用 jax.numpy.cumprod

  • 弃用/移除:

    • 内部子模块 jax.prng 现已弃用。其内容可在 jax.extend.random 中找到。

    • 内部子模块路径 jax.linear_util 已被弃用。请改用 jax.extend.linear_util(属于 jax.extend: 一个用于扩展的模块 的一部分)

    • jax.random.PRNGKeyArrayjax.random.KeyArray 已被弃用。请使用 jax.Array 进行类型注解,并使用 jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) 进行类型化 prng 键的运行时检测。

    • 方法 PRNGKeyArray.unsafe_raw_array 已被弃用。请改用 jax.random.key_data()

    • jax.experimental.pjit.with_sharding_constraint 已弃用。请改用 jax.lax.with_sharding_constraint

    • 内部工具 jax.core.is_opaque_dtypejax.core.has_opaque_dtype 已被移除。不透明数据类型已重命名为扩展数据类型;请改用 jnp.issubdtype(dtype, jax.dtypes.extended)(自 jax v0.4.14 起可用)。

    • 实用工具 jax.interpreters.xla.register_collective_primitive 已被移除。此实用工具在最近的 JAX 版本中没有任何实际作用,调用它可以安全地移除。

    • 内部子模块路径 jax.linear_util 已被弃用。请改用 jax.extend.linear_util(属于 jax.extend: 一个用于扩展的模块 的一部分)

jaxlib 0.4.16 (2023年9月18日)#

  • 更改:

    • 通过实验性 jax 稀疏 API 进行的稀疏 CSR 矩阵乘法在 NVIDIA GPU 上不再使用确定性算法。此更改是为了提高与 CUDA 12.2.1 的兼容性。

  • 错误修复:

    • 修复了由于与乱序节和 IMAGE_REL_AMD64_ADDR32NB 重定位相关的致命 LLVM 错误导致的 Windows 崩溃问题(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。

jax 0.4.14 (2023年7月27日)#

  • 更改

    • jax.jit 接受 donate_argnames 作为参数。它的语义与 static_argnames 相似。如果既没有提供 donate_argnums 也没有提供 donate_argnames,则不会捐赠任何参数。如果提供了 donate_argnames 但没有提供 donate_argnums,或者反之,JAX 使用 inspect.signature(fun) 来查找与 donate_argnames 对应的任何位置参数(或反之)。如果同时提供了 donate_argnums 和 donate_argnames,则不会使用 inspect.signature,并且只会捐赠在 donate_argnums 或 donate_argnames 中列出的实际参数。

    • jax.random.gamma() 已被重构为一个更高效的算法,具有更稳健的端点行为(#16779)。这意味着对于给定的 key,在 JAX v0.4.13 和 v0.4.14 之间,gamma 及相关采样器(包括 jax.random.ball()jax.random.beta()jax.random.chisquare()jax.random.dirichlet()jax.random.generalized_normal()jax.random.loggamma()jax.random.t())返回的值序列将会改变。

  • 删除

    • in_axis_resourcesout_axis_resources 被弃用以来已经超过3个月,因此它们已从 pjit 中删除。请使用 in_shardingsout_shardings 作为替代。这是一个安全和简单的名称替换。它不会改变任何当前 pjit 的语义,也不会破坏任何代码。您仍然可以将 PartitionSpecs 传递给 in_shardings 和 out_shardings。

  • 弃用

    • 根据 https://jax.readthedocs.io/en/latest/deprecation.html ,Python 3.8 的支持已被放弃。

    • 根据 https://jax.readthedocs.io/en/latest/deprecation.html ,JAX 现在要求 NumPy 1.22 或更新版本。

    • 通过位置传递 jax.numpy.ndarray.at() 的可选参数不再支持,这在 JAX 版本 0.4.7 中已被弃用。例如,不要使用 x.at[i].get(True),而应使用 x.at[i].get(indices_are_sorted=True)

    • 以下 jax.Array 方法在 JAX v0.4.5 中被弃用后已被移除:

    • 以下API在之前的弃用后已被移除:

      • jax.ad: 使用 jax.interpreters.ad

      • jax.curry: 使用 curry = lambda f: partial(partial, f).

      • jax.partial_eval: 使用 jax.interpreters.partial_eval

      • jax.pxla: 使用 jax.interpreters.pxla

      • jax.xla: 使用 jax.interpreters.xla

      • jax.ShapedArray: 使用 jax.core.ShapedArray

      • jax.interpreters.pxla.device_put: 使用 jax.device_put()

      • jax.interpreters.pxla.make_sharded_device_array: 使用 jax.make_array_from_single_device_arrays()

      • jax.interpreters.pxla.ShardedDeviceArray: 使用 jax.Array

      • jax.numpy.DeviceArray: 使用 jax.Array

      • jax.stages.Compiled.compiler_ir: 使用 jax.stages.Compiled.as_text()

  • 重大变更

    • JAX 现在需要 ml_dtypes 版本 0.2.0 或更新版本。

    • 为了修复一个特殊情况,调用带有五个参数的 jax.lax.cond() 将始终解析为“公共操作数” cond 行为(如文档所述),如果第二个和第三个参数是可调用的,即使其他操作数也是可调用的。参见 #16413

    • 已弃用的配置选项 jax_arrayjax_jit_pjit_api_merge 已被移除,这些选项没有任何作用。这些选项在多个版本中默认情况下一直为真。

  • 新功能

    • JAX 现在支持一个配置标志 –jax_serialization_version 和一个 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本(#16746)。

    • jax2tf 在存在形状多态性的情况下,如果序列化版本至少为7,现在会生成检查某些形状约束的代码。参见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。

jaxlib 0.4.14 (2023年7月27日)#

  • 弃用

    • 根据 https://jax.readthedocs.io/en/latest/deprecation.html ,Python 3.8 的支持已被放弃。

jax 0.4.13 (2023年6月22日)#

  • 更改

    • jax.jit 现在允许将 None 传递给 in_shardingsout_shardings。其语义如下:

      • 对于 in_shardings,JAX 会将其标记为复制,但这种行为在未来可能会改变。

      • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

    • jax.experimental.pjit.pjit 也允许将 None 传递给 in_shardingsout_shardings。其语义如下:

      • 如果未提供网格上下文管理器,JAX 可以自由选择任何它想要的拆分方式。

        • 对于 in_shardings,JAX 会将其标记为复制,但这种行为在未来可能会改变。

        • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

      • 如果提供了网格上下文管理器,None 将意味着该值将在网格的所有设备上复制。

    • Executable.cost_analysis() 适用于 Cloud TPU

    • 如果使用了非允许列表中的 jaxlib 插件,则添加了一个警告。

    • 添加了 jax.tree_util.tree_leaves_with_path

    • None 不是 jax.experimental.multihost_utils.host_local_array_to_global_arrayjax.experimental.multihost_utils.global_array_to_host_local_array 的有效输入。如果你想复制你的输入,请使用 jax.sharding.PartitionSpec()

  • 错误修复

    • 在CUDA 12版本中修复了错误的wheel名称(#16362);正确的wheel名称应为 cudnn89 而不是 cudnn88

  • 弃用

    • native_serialization_strict_checks 参数在 jax.experimental.jax2tf.convert() 中已被弃用,取而代之的是新的 native_serializaation_disabled_checks (#16347)。

jaxlib 0.4.13 (2023年6月22日)#

  • 更改

    • jaxlib Pypi 发布中添加了仅限 Windows CPU 的轮子。

  • 错误修复

    • __cuda_array_interface__ 在之前的 jaxlib 版本中存在问题,现已修复 (#16440)。

    • 并发CUDA内核追踪现在在NVIDIA GPU上默认启用。

jax 0.4.12 (2023年6月8日)#

  • 更改

  • 弃用

    • jax.abstract_arrays 及其内容现已弃用。请参阅 :mod:jax.core 中的相关功能。

    • jax.numpy.alltrue: 使用 jax.numpy.all。这遵循了 NumPy 版本 1.25.0 中 numpy.alltrue 的弃用。

    • jax.numpy.sometrue:使用 jax.numpy.any。这遵循了 numpy.sometrue 在 NumPy 版本 1.25.0 中的弃用。

    • jax.numpy.product: 使用 jax.numpy.prod。这遵循了 numpy.product 在 NumPy 版本 1.25.0 中的弃用。

    • jax.numpy.cumproduct: 使用 jax.numpy.cumprod。这遵循了 NumPy 版本 1.25.0 中 numpy.cumproduct 的弃用。

    • jax.sharding.OpShardingSharding 已被移除,因为它已被弃用3个月。

jaxlib 0.4.12 (2023年6月8日)#

  • 更改

    • 包含适用于 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。之前版本的 jaxlib 应该也能在 Hopper 上工作,但在第一次执行 JAX 操作时会有较长的 JIT 编译延迟。

  • 错误修复

    • 修复了在Python 3.11下JAX生成的Python回溯中源行信息不正确的问题。

    • 修复了在打印JAX生成的Python回溯中帧的局部变量时发生的崩溃问题 (#16027)。

jax 0.4.11 (2023年5月31日)#

  • 弃用

    • 根据 api-兼容性 政策,以下API在3个月的弃用期后已被移除:

      • jax.experimental.PartitionSpec: 使用 jax.sharding.PartitionSpec

      • jax.experimental.maps.Mesh: 使用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding: 使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec: 使用 jax.sharding.PartitionSpec

      • jax.experimental.pjit.FROM_GDA。改为传递分片的 jax.Array 对象作为输入,并移除 pjit 的可选 in_shardings 参数。

      • jax.interpreters.pxla.PartitionSpec: 使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh: 使用 jax.sharding.Mesh

      • jax.interpreters.xla.Buffer:使用 jax.Array

      • jax.interpreters.xla.Device: 使用 jax.Device

      • jax.interpreters.xla.DeviceArray: 使用 jax.Array

      • jax.interpreters.xla.device_put: 使用 jax.device_put

      • jax.interpreters.xla.xla_call_p: 使用 jax.experimental.pjit.pjit_p

      • with_sharding_constraintaxis_resources 参数已被移除。请改用 shardings

jaxlib 0.4.11 (2023年5月31日)#

  • 更改

    • Device 添加了 memory_stats() 方法。如果支持,这将返回一个字符串状态名称与整数值的字典,例如 "bytes_in_use",或者如果平台不支持内存统计则返回 None。返回的确切状态可能因平台而异。目前仅在 Cloud TPU 上实现。

    • 重新添加了对CPU设备上Python缓冲协议(memoryview)的支持。

jax 0.4.10 (2023年5月11日)#

jaxlib 0.4.10 (2023年5月11日)#

  • 更改

    • 修复了 'apple-m1' 不是此目标的已识别处理器(忽略处理器) 问题,该问题导致之前的版本无法在 Mac M1 上运行。

jax 0.4.9 (2023年5月9日)#

  • 更改

    • 已移除标志 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap。它们现在总是开启的。

    • 在TPU上进行奇异值分解(SVD)的准确性已得到提升(需要 jaxlib 0.4.9)。

  • 弃用

    • jax.experimental.gda_serialization 已被弃用并更名为 jax.experimental.array_serialization。请更改您的导入以使用 jax.experimental.array_serialization

    • pjit 的 in_axis_resourcesout_axis_resources 参数已被弃用。请分别使用 in_shardingsout_shardings

    • 函数 jax.numpy.msort 已被移除。自 JAX v0.4.1 起已被弃用。请改用 jnp.sort(a, axis=0)

    • sharded_jitsharded_jit 早已不再使用以来,in_partsout_parts 参数已从 jax.xla_computation 中移除。

    • instantiate_const_outputs 参数已从 jax.xla_computation 中移除,因为它已经很长时间未被使用。

jaxlib 0.4.9 (2023年5月9日)#

jax 0.4.8 (2023年3月29日)#

  • 重大变更

    • Cloud TPU 运行时的一个主要组件已升级。这使得 Cloud TPU 上启用了以下新功能:

      jax.experimental.host_callback() 在新运行时组件下不再支持在云TPU上使用。如果新的 jax.debug API 无法满足您的使用需求,请在 JAX问题追踪器 提交问题。

      旧的运行时组件将通过设置环境变量 JAX_USE_PJRT_C_API_ON_TPU=false 在接下来的至少三个月内可用。如果你发现由于任何原因需要禁用新运行时,请在 JAX issue tracker 上告知我们。

  • 更改

    • jaxlib 的最低版本已从 0.4.6 提升至 0.4.7。

  • 弃用

    • CUDA 11.4 支持已被移除。JAX GPU 轮子仅支持 CUDA 11.8 和 CUDA 12。旧版本的 CUDA 如果从源码构建 jaxlib 可能仍然有效。

    • global_arg_shapes 参数仅在 sharded_jit 中与 pmap 一起工作,并已从 pmap 中移除。请迁移到 pjit 并从 pmap 中移除 global_arg_shapes。

jax 0.4.7 (2023年3月27日)#

  • 更改

    • 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration jax.config.jax_array 不能再被禁用。

    • jax.config.jax_jit_pjit_api_merge 不能再被禁用了。

    • jax.experimental.jax2tf.convert() 现在支持 native_serialization 参数,以使用 JAX 的本地降低到 StableHLO,从而为整个 JAX 函数获取一个 StableHLO 模块,而不是将每个 JAX 原语降低为一个 TensorFlow 操作。这简化了内部结构,并增加了你序列化的内容与 JAX 本地语义匹配的信心。参见 文档。作为此更改的一部分,配置标志 --jax2tf_default_experimental_native_lowering 已重命名为 --jax2tf_native_serialization

    • JAX 现在依赖于 ml_dtypes,其中包含了像 bfloat16 这样的 NumPy 类型的定义。这些定义之前是 JAX 内部的,但为了方便与其他项目共享,已被拆分为一个单独的包。

    • JAX 现在需要 NumPy 1.21 或更新版本,以及 SciPy 1.7 或更新版本。

  • 弃用

    • 类型 jax.numpy.DeviceArray 已被弃用。请改用 jax.Array,它是 jax.Array 的别名。

    • 类型 jax.interpreters.pxla.ShardedDeviceArray 已被弃用。请改用 jax.Array

    • 通过位置向 jax.numpy.ndarray.at() 传递额外参数已被弃用。例如,不要使用 x.at[i].get(True),而应使用 x.at[i].get(indices_are_sorted=True)

    • jax.interpreters.xla.device_put 已弃用。请使用 jax.device_put

    • jax.interpreters.pxla.device_put 已被弃用。请使用 jax.device_put

    • jax.experimental.pjit.FROM_GDA 已被弃用。请将分片的 jax.Arrays 作为输入传递,并删除 in_shardings 参数,因为它现在是可选的。

jaxlib 0.4.7 (2023年3月27日)#

更改:

  • jaxlib 现在依赖于 ml_dtypes,其中包含了诸如 bfloat16 这样的 NumPy 类型定义。这些定义之前是 JAX 内部的,但为了方便与其他项目共享,已被拆分为一个单独的包。

jax 0.4.6 (2023年3月9日)#

  • 更改

    • jax.tree_util 现在包含一组API,允许用户为其自定义pytree节点定义键。这包括:

      • tree_flatten_with_path 用于展平树结构,并返回不仅每个叶子节点,还有它们的路径键。

      • tree_map_with_path 可以映射一个函数,该函数将键路径作为参数。

      • register_pytree_with_keys 用于注册自定义 pytree 节点中键路径和叶子应如何显示。

      • keystr 用于美化打印一个键路径。

    • jax2tf.call_tf() 有一个新参数 output_shape_dtype(默认 None),可用于声明结果的输出形状和类型。这使得 jax2tf.call_tf() 在存在形状多态性的情况下也能工作。(#14734)。

  • 弃用

    • jax.tree_util 中的旧键路径API已被弃用,并将于2023年3月10日起3个月后移除:

jaxlib 0.4.6 (2023年3月9日)#

jax 0.4.5 (2023年3月2日)#

  • 弃用

    • jax.sharding.OpShardingSharding 已重命名为 jax.sharding.GSPMDShardingjax.sharding.OpShardingSharding 将于2023年2月17日起3个月后移除。

    • 以下 jax.Array 方法已被弃用,并将于2023年2月23日后的3个月内移除:

jax 0.4.4 (2023年2月16日)#

  • 更改

    • jitpjit 的实现已经合并。合并 jit 和 pjit 改变了 JAX 的内部结构,但不会影响 JAX 的公共 API。之前,jit 是一个最终风格的原语。最终风格意味着 jaxpr 的创建被尽可能地延迟,并且转换被堆叠在一起。随着 jit-pjit 实现合并,jit 变成了一个初始风格的原语,这意味着我们尽可能早地跟踪到 jaxpr。更多信息请参见 autodidax 中的这一部分。转向初始风格应该简化 JAX 的内部结构,并使动态形状等功能的开发更容易。您只能通过环境变量禁用它,即 os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。由于合并会影响 JAX 在导入时的行为,因此需要在导入 jax 之前通过环境变量禁用合并。

    • with_sharding_constraintaxis_resources 参数已被弃用。请改用 shardings。如果你之前将 axis_resources 用作参数,则无需更改。如果你之前将其用作关键字参数,则请改用 shardingsaxis_resources 将于2023年2月13日后的3个月内移除。

    • 添加了 jax.typing 模块,提供了对 JAX 函数进行类型注解的工具。

    • 以下名称已被弃用:

      • jax.xla.Devicejax.interpreters.xla.Device: 使用 jax.Device

      • jax.experimental.maps.Mesh。请改用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding: 使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec: 使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh: 使用 jax.sharding.Mesh

      • jax.interpreters.pxla.PartitionSpec: 使用 jax.sharding.PartitionSpec

  • 重大变更

    • 对于像 :func:jax.numpy.sum 这样的归约函数,initial 参数现在必须是一个标量,这与相应的 NumPy API 一致。之前对非标量 initial 值进行广播输出的行为是一个无意的实现细节 (#14446)。

jaxlib 0.4.4 (2023年2月16日)#

  • 重大变更

    • 默认的 jaxlib 构建中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以通过源码构建 jaxlib 来支持 Kepler(通过 build.py--cuda_compute_capabilities=sm_35 选项),但请注意 CUDA 12 已完全放弃对 Kepler GPU 的支持。

jax 0.4.3 (2023年2月8日)#

jaxlib 0.4.3 (2023年2月8日)#

  • jax.Array 现在有了非阻塞的 is_ready() 方法,如果数组准备就绪,则返回 True(另见 jax.block_until_ready())。

jax 0.4.2 (2023年1月24日)#

  • 重大变更

    • 删除了 jax.experimental.callback

    • 在存在 jax2tf 形状多态性的情况下,通过将符号维度转换为 JAX 数组,维度操作已推广到更多场景中。涉及符号维度和 np.ndarray 的操作现在可以在结果用作形状值时引发错误(#14106)。

    • jaxpr 对象现在在设置属性时会引发错误,以避免有问题的突变 (#14102)

  • 更改

    • jax2tf.call_tf() 有一个新参数 has_side_effects(默认值为 True),可用于声明实例是否可以被 JAX 优化(如死代码消除)移除或复制(#13980)。

    • 为 jax2tf 形状多态性增加了对 floordiv 和 mod 的更多支持。之前,某些除法操作在存在符号维度时会导致错误(#14108)。

jaxlib 0.4.2 (2023年1月24日)#

  • 更改

    • 设置 JAX_USE_PJRT_C_API_ON_TPU=1 以启用新的 Cloud TPU 运行时,具有自动设备内存碎片整理功能。

jax 0.4.1 (2022年12月13日)#

  • 更改

    • 根据 JAX 的 版本支持策略,Python 3.7 的支持已被取消。

    • 我们引入了 jax.Array,这是一种统一的数组类型,它包含了 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 类型。jax.Array 类型有助于使并行成为 JAX 的核心功能,简化和统一 JAX 的内部结构,并允许我们统一 jitpjitjax.Array 在 JAX 0.4 中已默认启用,并对 pjit API 进行了一些重大更改。jax.Array 迁移指南 可以帮助你将代码库迁移到 jax.Array。你还可以查看 分布式数组和自动并行化 教程以理解新概念。

    • PartitionSpecMesh 现已脱离实验阶段。新的 API 端点是 jax.sharding.PartitionSpecjax.sharding.Meshjax.experimental.maps.Meshjax.experimental.PartitionSpec 已被弃用,并将在 3 个月后移除。

    • with_sharding_constraint 的新公共端点是 jax.lax.with_sharding_constraint

    • 如果同时使用 ABSL 标志和 jax.config,在 JAX 配置选项从 ABSL 标志初始填充后,ABSL 标志值将不再被读取或写入。这一更改提高了读取 jax.config 选项的性能,这些选项在 JAX 中被广泛使用。

    • jax2tf.call_tf 函数现在使用与嵌入 JAX 计算相同的平台的首个 TF 设备进行 TF 降级。之前,它使用的是 JAX 默认后端的 0 号设备。

    • 一些 jax.numpy 函数的参数现在被标记为仅位置参数,与 NumPy 匹配。

    • jnp.msort 现在已被弃用,与 numpy 1.24 中 np.msort 的弃用一致。根据 API 兼容性 政策,它将在未来的版本中被移除。可以用 jnp.sort(a, axis=0) 替代。

jaxlib 0.4.1 (2022年12月13日)#

  • 更改

    • 根据 JAX 的 版本支持策略,Python 3.7 的支持已被取消。

    • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX 的行为已更改为分配总 GPU 内存的 XX%,而不是以前使用当前可用 GPU 内存来计算预分配的行为。更多详情请参阅 GPU 内存分配

    • 已弃用的方法 .block_host_until_ready() 已被移除。请改用 .block_until_ready()

jax 0.4.0 (2022年12月12日)#

  • 该版本已被撤回。

jaxlib 0.4.0 (2022年12月12日)#

  • 该版本已被撤回。

jax 0.3.25 (2022年11月15日)#

jaxlib 0.3.25 (2022年11月15日)#

  • 更改

    • 增加了对CPU和GPU上三对角化简的支持。

    • 增加了对CPU上上Hessenberg约简的支持。

  • 漏洞

    • 修复了一个错误,该错误意味着在 Python 3.10+ 下,JAX 捕获的回溯帧错误地映射到源代码行。

jax 0.3.24 (2022年11月4日)#

  • 更改

    • JAX 的导入速度应该会更快。我们现在延迟导入 scipy,这占用了 JAX 导入时间的一大部分。

    • 设置环境变量 JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N 可以用来限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过1秒的计算将被缓存。

    • 如果在 TPU 上使用 pmap 时未指定顺序,默认设备顺序现在与单进程作业的 jax.devices() 匹配。以前两者的顺序不同,这可能导致不必要的复制或内存不足错误。要求顺序一致简化了问题。

  • 重大变更

  • 弃用

    • jax.sharding.MeshPspecSharding 已重命名为 jax.sharding.NamedShardingjax.sharding.MeshPspecSharding 名称将在3个月后移除。

jaxlib 0.3.24 (2022年11月4日)#

  • 更改

    • 缓冲区捐赠现在支持CPU。这可能会破坏那些在CPU上标记缓冲区进行捐赠但依赖于捐赠未实现的代码。

jax 0.3.23 (2022年10月12日)#

  • 更改

    • 更新 Colab TPU 驱动版本以适应新的 jaxlib 发布。

jax 0.3.22 (2022年10月11日)#

  • 更改

    • 在TPU初始化时,添加 JAX_PLATFORMS=tpu,cpu 作为默认设置,这样如果TPU无法初始化,JAX将抛出错误而不是回退到CPU。设置 JAX_PLATFORMS='' 以覆盖此行为并自动选择可用的后端(原始默认值),或者设置 JAX_PLATFORMS=cpu 以始终使用CPU,无论TPU是否可用。

  • 弃用

    • 在 JAX v0.3.8 中被弃用的几个测试工具现在已从 jax.test_util 中移除。

jaxlib 0.3.22 (2022年10月11日)#

jax 0.3.21 (2022年9月30日)#

  • GitHub 提交

  • 更改

    • 持久编译缓存现在会在出错时发出警告,而不是抛出异常(#12582),因此如果缓存出现问题,程序执行可以继续。设置 JAX_RAISE_PERSISTENT_CACHE_ERRORS=true 以恢复此行为。

jax 0.3.20 (2022年9月28日)#

  • 错误修复:

    • 添加了之前版本中缺失的 .pyi 文件(#12536)。

    • 修复了 jax 0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。

    • 修复 setup.py 注释中的 pip 错误 URL (#12528)。

jaxlib 0.3.20 (2022年9月28日)#

  • GitHub 提交

  • 错误修复

    • 修复了通过 jax_cuda_visible_devices 在分布式作业中限制可见 CUDA 设备的支持。此功能对于 JAX/SLURM 在 GPU 上的集成是必需的(#12533)。

jax 0.3.19 (2022年9月27日)#

jax 0.3.18 (2022年9月26日)#

  • GitHub 提交

  • 更改

    • 提前降低和编译功能(在 #7733 中跟踪)是稳定且公开的。请参阅 概述jax.stages 的 API 文档。

    • 引入了 jax.Array,旨在用于 JAX 中数组类型的 isinstance 检查和类型注解。请注意,这包括了对 jax-内部对象的 jax.numpy.ndarrayisinstance 工作方式的一些细微变化,因为 jax.numpy.ndarray 现在是 jax.Array 的一个简单别名。

  • 重大变更

    • jax._src 不再导入到公共的 jax 命名空间中。这可能会破坏使用 JAX 内部功能的用户。

    • jax.soft_pmap 已被删除。请改用 pjitxmapjax.soft_pmap 未被文档化。如果它被文档化,将会提供一个弃用期。

jax 0.3.17 (2022年8月31日)#

  • GitHub 提交

  • 漏洞

    • 修复了 lax.pow 中指数为零时的边缘情况问题(#12041

  • 重大变更

    • jax.checkpoint(),也称为 jax.remat(),不再支持 concrete 选项,这是遵循先前版本的弃用;参见 JEP 11830

  • 更改

    • 添加了 jax.pure_callback(),它使得可以从编译函数(例如用 jax.jitjax.pmap 装饰的函数)回调到纯 Python 函数。

  • 弃用:

    • 已弃用的 DeviceArray.tile() 方法已被移除。请使用 jax.numpy.tile() (#11944)。

    • DeviceArray.to_py() 已被弃用。请改用 np.asarray(x)

jax 0.3.16#

jax 0.3.15 (2022年7月22日)#

jaxlib 0.3.15 (2022年7月22日)#

jax 0.3.14 (2022年6月27日)#

  • GitHub 提交

  • 重大变更

    • jax.experimental.compilation_cache.initialize_cache() 不再支持 max_cache_size_ 字节 并且不会将其作为输入。

    • JAX_PLATFORMS 现在在平台初始化失败时会引发异常。

  • 更改

    • 修复了与 NumPy 1.23 的兼容性问题。

    • jax.numpy.linalg.slogdet() 现在接受一个可选的 method 参数,允许在基于LU分解的实现和基于QR分解的实现之间进行选择。

    • jax.numpy.linalg.qr() 现在支持 mode="raw"

    • pickle, copy.copy, 和 copy.deepcopy 现在在使用 JAX 数组时有了更完整的功能支持(#10659)。特别是:

      • pickledeepcopy 之前在使用 DeviceArray 时返回 np.ndarray 对象;现在返回 DeviceArray 对象。对于 deepcopy,复制的数组与原始数组位于同一设备上。对于 pickle,反序列化的数组将位于默认设备上。

      • 在函数变换中(即跟踪的代码),deepcopycopy 以前是空操作。现在它们使用与 DeviceArray.copy() 相同的机制。

      • 对跟踪的数组调用 pickle 现在会导致显式的 ConcretizationTypeError

    • 在TPU上,奇异值分解(SVD)和对称/厄米特矩阵的特征分解的实现应该会显著加快,特别是对于1000x1000或更大的矩阵。两者现在都使用谱分治算法进行特征分解(QDWH-eig)。

    • jax.numpy.ldexp() 不再静默地将所有输入提升为 float64,而是对于大小为 int32 或更小的整数输入,提升为 float32(#10921)。

    • jax.profiler.start_trace()jax.profiler.start_trace() 添加 create_perfetto_link 选项。使用时,分析器将生成一个链接到 Perfetto UI 以查看跟踪。

    • 更改了 jax.profiler.start_server(...)() 的语义,使其全局存储 keepalive,而不是要求用户保留对其的引用。

    • 添加了 jax.random.generalized_normal()

    • 添加了 jax.random.ball()

    • 添加了 jax.default_device()

    • 添加了一个 python -m jax.collect_profile 脚本,作为手动捕获程序跟踪的替代方法,而不是使用 TensorBoard UI。

    • 添加了一个 jax.named_scope 上下文管理器,它为 Python 程序添加了分析器元数据(类似于 jax.named_call)。

    • 在散布更新操作(即 :attr:jax.numpy.ndarray.at)中,不安全的隐式数据类型转换已被弃用,现在会导致 FutureWarning。在未来的版本中,这将变为一个错误。一个不安全的隐式转换的例子是 jnp.zeros(4, dtype=int).at[0].set(1.5),其中 1.5 之前会被静默截断为 1

    • jax.experimental.compilation_cache.initialize_cache() 现在支持将 gcs 存储桶路径作为输入。

    • 添加了 jax.scipy.stats.gennorm()

    • jax.numpy.roots() 现在在 strip_zeros=False 时表现更好,当系数有前导零时(#11215)。

jaxlib 0.3.14 (2022年6月27日)#

  • GitHub 提交

    • x86-64 Mac 轮子现在需要 Mac OS 10.14 (Mojave) 或更新版本。Mac OS 10.14 于2018年发布,因此这不应是一个非常苛刻的要求。

    • NCCL 的捆绑版本已更新至 2.12.12,修复了一些死锁问题。

    • Python flatbuffers 包不再是 jaxlib 的依赖项。

jax 0.3.13 (2022年5月16日)#

jax 0.3.12 (2022年5月15日)#

jax 0.3.11 (2022年5月15日)#

  • GitHub 提交

  • 更改

    • jax.lax.eigh() 现在接受一个可选的 sort_eigenvalues 参数,允许用户在TPU上选择退出特征值排序。

  • 弃用

    • jax.lax.linalg 中的函数非数组参数现在被标记为仅关键字。作为向后兼容的步骤,传递仅关键字参数的位置会产生警告,但在未来的 JAX 版本中,传递仅关键字参数的位置将会失败。然而,大多数用户应该更喜欢使用 jax.numpy.linalg

    • jax.scipy.linalg.polar_unitary(),这是对scipy API的JAX扩展,已被弃用。请改用jax.scipy.linalg.polar()

jax 0.3.10 (2022年5月3日)#

jaxlib 0.3.10 (2022年5月3日)#

  • GitHub 提交

  • 更改

    • TF commit 修复了MHLO规范化器中的一个问题,该问题导致某些程序的常量折叠耗时过长或崩溃。

jax 0.3.9 (2022年5月2日)#

  • GitHub 提交

  • 更改

    • 为 GlobalDeviceArray 添加了对完全异步检查点的支持。

jax 0.3.8 (2022年4月29日)#

  • GitHub 提交

  • 更改

    • jax.numpy.linalg.svd() 在 TPU 上使用 qdwh-svd 求解器。

    • jax.numpy.linalg.cond() 在 TPU 上现在接受复数输入。

    • jax.numpy.linalg.pinv() 现在在 TPU 上接受复数输入。

    • jax.numpy.linalg.matrix_rank() 现在在 TPU 上接受复数输入。

    • jax.scipy.cluster.vq.vq() 已添加。

    • jax.experimental.maps.mesh 已被删除。请使用 jax.experimental.maps.Mesh。更多信息请参见 https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh。

    • jax.scipy.linalg.qr() 现在在 mode='r' 时返回一个长度为1的元组,而不是原始数组,以匹配 scipy.linalg.qr 的行为(#10452

    • jax.numpy.take_along_axis() 现在接受一个可选的 mode 参数,该参数指定越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在JAX的早期版本中,无效索引会被限制在范围内。可以通过传递 mode="clip" 来恢复之前的行为。

    • jax.numpy.take() 现在默认使用 mode="fill",这会为越界索引返回无效值(例如,NaN)。

    • 散布操作,例如 x.at[...].set(...),现在具有 "drop" 语义。这不会影响散布操作本身,但它意味着在微分时,散布的梯度将为越界索引产生零余切。以前,越界索引在梯度计算时会被限制在范围内,这在数学上是不正确的。

    • jax.numpy.take_along_axis() 现在如果其索引不是整数类型,则会引发 TypeError,这与 numpy.take_along_axis() 的行为相匹配。以前非整数索引会被静默转换为整数。

    • jax.numpy.ravel_multi_index() 现在如果其 dims 参数不是整数类型,则会引发 TypeError,这与 numpy.ravel_multi_index() 的行为相匹配。以前非整数的 dims 会被静默转换为整数。

    • jax.numpy.split() 现在如果其 axis 参数不是整数类型,则会引发 TypeError,这与 numpy.split() 的行为相匹配。以前非整数的 axis 会被静默转换为整数。

    • jax.numpy.indices() 现在如果其维度不是整数类型,则会引发 TypeError,这与 numpy.indices() 的行为相匹配。以前非整数维度会被静默转换为整数。

    • jax.numpy.diag() 现在如果其 k 参数不是整数类型,则会引发 TypeError,这与 numpy.diag() 的行为相匹配。以前非整数的 k 会被静默转换为整数。

    • 添加了 jax.random.orthogonal()

  • 弃用

    • jax.test_util 中可用的许多函数和对象现在已被弃用,并且在导入时会引发警告。这包括 cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, 和 _default_tolerance (#10389)。这些,连同之前已弃用的 JaxTestCase, JaxTestLoader, 和 BufferDonationTestCase,将在未来的 JAX 版本中被移除。这些工具中的大多数可以通过调用标准 python & numpy 测试工具来替代,例如 unittest, absl.testing, numpy.testing 等。JAX 特定的功能,如设备检查,可以通过使用公共 API 如 jax.devices() 来替代。许多已弃用的工具仍将存在于 jax._src.test_util 中,但这些不是公共 API,因此可能在未来的版本中被更改或移除,恕不另行通知。

jax 0.3.7 (2022年4月15日)#

jaxlib 0.3.7 (2022年4月15日)#

  • 更改:

    • Linux 轮子现在按照 manylinux2014 标准构建,而不是 manylinux2010

jax 0.3.6 (2022年4月12日)#

  • GitHub 提交

  • 更改:

    • 升级了 libtpu 轮子到修复了初始化 TPU pod 时挂起问题的版本。修复了 #10218

  • 弃用:

    • jax.experimental.loops 已被弃用。请参阅 #10278 获取替代 API。

jax 0.3.5 (2022年4月7日)#

jaxlib 0.3.5 (2022年4月7日)#

  • 错误修复

    • 修复了一个错误,即双精度复数到实数的 IRFFTs 会在 GPU 上改变其输入缓冲区(#9946)。

    • 修复了复杂散射的常量折叠错误(#10159

jax 0.3.4 (2022年3月18日)#

jax 0.3.3 (2022年3月17日)#

jax 0.3.2 (2022年3月16日)#

  • GitHub 提交

  • 更改:

    • 函数 jax.ops.index_update, jax.ops.index_add 在 0.2.22 版本中已被弃用,现已被移除。请改用 JAX 数组的 .at 属性,例如 x.at[idx].set(y)

    • jax.experimental.ann.approx_*_k 移动到 jax.lax 中。这些函数是 jax.lax.top_k 的优化替代方案。

    • jax.numpy.broadcast_arrays()jax.numpy.broadcast_to() 现在需要标量或类数组输入,如果传递列表则会失败(属于 #7737 的一部分)。

    • 标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。

    • pjit 现在可以在 CPU 上运行(除了之前的 TPU 和 GPU 支持外)。

jaxlib 0.3.2 (2022年3月16日)#

  • 更改

    • XlaComputation.as_hlo_text() 现在通过传递布尔标志 print_large_constants=True 支持打印大常量。

  • 弃用:

    • JAX 数组上的 .block_host_until_ready() 方法已被弃用。请改用 .block_until_ready()

jax 0.3.1 (2022年2月18日)#

jax 0.3.0 (2022年2月10日)#

jaxlib 0.3.0 (2022年2月10日)#

  • 更改

    • 构建 jaxlib 现在需要 Bazel 5.0.0。

    • jaxlib 版本已更新至 0.3.0。请参阅 设计文档 以了解解释。

jax 0.2.28 (2022年2月1日)#

  • GitHub 提交

    • jax.jit(f).lower(...).compiler_ir() 现在默认使用 MHLO 方言,如果没有传递 dialect= 参数。

    • jax.jit(f).lower(...).compiler_ir(dialect='mhlo') 现在返回一个 MLIR ir.Module 对象,而不是它的字符串表示。

jaxlib 0.1.76 (2022年1月27日)#

  • 新功能

    • 包含为NVidia计算能力8.0的GPU(例如A100)预编译的SASS。移除计算能力6.1的预编译SASS,以避免增加计算能力的数量:计算能力6.1的GPU可以使用6.0的SASS。

    • 在 jaxlib 0.1.76 版本中,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。

  • 重大变更

    • 根据 弃用政策,已不再支持 NumPy 1.18。请升级到支持的 NumPy 版本。

  • 错误修复

    • 修复了一个错误,即通过不同路径构造的明显相同的 pytreedef 对象不会被比较为相等 (#9066)。

    • JAX 的 jit 缓存要求两个静态参数具有相同的类型才能命中缓存 (#9311)。

jax 0.2.27 (2022年1月18日)#

  • GitHub 提交

  • 重大变更:

    • 根据 弃用政策,已不再支持 NumPy 1.18。请升级到支持的 NumPy 版本。

    • host_callback 原语已简化,删除了 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只有原始值会被触发。旧的行为可以通过设置 JAX_HOST_CALLBACK_AD_TRANSFORMS 环境变量或 --jax_host_callback_ad_transforms 标志来获得(有限时间内)。此外,增加了如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。

    • 排序现在与 NumPy 的行为一致,无论位表示如何,都匹配 0.0NaN。特别是,0.0-0.0 现在被视为等价,而之前 -0.0 被视为小于 0.0。此外,所有 NaN 表示现在都被视为等价,并排序到数组的末尾。以前,负 NaN 值被排序到数组的前面,并且具有不同内部位表示的 NaN 值不被视为等价,而是根据这些位模式进行排序(#9178)。

    • jax.numpy.unique() 现在以与 np.unique 在 NumPy 1.21 及更高版本中相同的方式处理 NaN 值:在唯一化的输出中最多会出现一个 NaN 值(#9184)。

  • 错误修复:

    • host_callback 现在支持 ad_checkpoint.checkpoint (#8907)。

  • 新功能:

    • 添加 jax.block_until_ready ({jax-issue}`#8941)

    • 添加了一个新的调试标志/环境变量 JAX_DUMP_IR_TO=/path。如果设置,JAX 会将它为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件中。

    • jax.ensure_compile_time_eval 添加到公共 API 中(#7987)。

    • jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions,以改变关联归约(例如,jnp.cumsum)的降低行为,使其在 CPU 和 GPU 上表现得像 JAX(使用关联扫描)。更多详情请参见 jax2tf README(#9189)。

jaxlib 0.1.75 (2021年12月8日)#

  • 新功能:

    • 对 Python 3.10 的支持

jax 0.2.26 (2021年12月8日)#

  • GitHub 提交

  • 错误修复:

    • 超出边界的索引在 jax.ops.segment_sum 中现在将按照文档所述的处理方式,使用 FILL_OR_DROP 语义。这主要影响反向模式导数,其中对应于超出边界索引的梯度现在将返回为 0。(#8634)。

    • jax2tf 将强制转换后的代码在 jax.jit 下的代码片段使用 XLA,例如,大多数 jax.numpy 函数 (#7839)。

jaxlib 0.1.74 (2021年11月17日)#

  • 启用了GPU之间的点对点复制。以前,GPU复制通过主机进行,这通常较慢。

  • 为 JAX 添加了实验性的 MLIR Python 绑定。

jax 0.2.25 (2021年11月10日)#

  • GitHub 提交

  • 新功能:

    • (实验性) jax.distributed.initialize 暴露了多主机GPU后端。

    • jax.random.permutation 支持新的 independent 关键字参数 (#8430)

  • 重大变更

    • jax.experimental.stax 移动到 jax.example_libraries.stax

    • jax.experimental.optimizers 移动到 jax.example_libraries.optimizers

  • 新功能:

    • 添加了 jax.lax.linalg.qdwh

jax 0.2.24 (2021年10月19日)#

  • GitHub 提交

  • 新功能:

    • jax.random.choicejax.random.permutation 现在支持多维数组和一个可选的 axis 参数 (#8158)

  • 重大变更:

    • jax.numpy.takejax.numpy.take_along_axis 现在需要类似数组的输入(参见 #7737

jaxlib 0.1.73 (2021年10月18日)#

  • jaxlib GPU cuda11 轮子现在支持多个 cuDNN 版本。

    • cuDNN 8.2 或更新版本。如果你的 cuDNN 安装足够新,我们推荐使用 cuDNN 8.2 的 wheel 文件,因为它支持额外的功能。

    • cuDNN 8.0.5 或更新版本。

  • 重大变更:

    • GPU jaxlib 的安装命令如下:

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      

jax 0.2.22 (2021年10月12日)#

  • GitHub 提交

  • 重大变更

    • 传递给 jax.pmap 的静态参数现在必须是可哈希的。

      不可哈希的静态参数在 jax.jit 上早已被禁止,但它们在 jax.pmap 上仍然被允许;jax.pmap 通过对象标识来比较不可哈希的静态参数。

      这种行为是一个潜在的危险,因为通过对象标识进行参数比较会导致每次对象标识改变时都需要重新编译。相反,我们现在禁止使用不可哈希的参数:如果 jax.pmap 的用户希望通过对象标识比较静态参数,他们可以在其对象上定义 __hash____eq__ 方法来实现这一点,或者将他们的对象包装在一个具有对象标识语义的这些操作的对象中。另一个选择是使用 functools.partial 将不可哈希的静态参数封装到函数对象中。

    • jax.util.partial 是一个意外的导出,现在已经被移除。请改用 Python 标准库中的 functools.partial

  • 弃用

    • 函数 jax.ops.index_update, jax.ops.index_add 等已被弃用,并将在未来的 JAX 版本中移除。请改用 JAX 数组的 .at 属性,例如 x.at[idx].set(y)。目前,这些函数会生成一个 DeprecationWarning

  • 新功能:

    • 在使用 jaxlib 0.1.72 或更新版本时,优化后的 C++ 代码路径现在默认用于改进 pmap 的调度时间。该功能可以通过 --experimental_cpp_pmap 标志(或 JAX_CPP_PMAP 环境变量)来禁用。

    • jax.numpy.unique 现在支持一个可选的 fill_value 参数 (#8121)

jaxlib 0.1.72 (2021年10月12日)#

  • 重大变更:

    • 已不再支持 CUDA 10.2 和 CUDA 10.1。Jaxlib 现在支持 CUDA 11.1+。

  • 错误修复:

    • 修复了 https://github.com/google/jax/issues/7461,该问题导致所有平台上的输出错误,原因是 XLA 编译器内部缓冲区别名不正确。

jax 0.2.21 (2021年9月23日)#

  • GitHub 提交

  • 重大变更

    • jax.api 已被移除。原本作为 jax.api.* 提供的函数是 jax.* 中函数的别名;请改用 jax.* 中的函数。

    • jax.partialjax.lax.partial 是意外导出的内容,现已被移除。请改用 Python 标准库中的 functools.partial

    • 布尔标量索引现在会引发 TypeError;之前这会静默返回错误结果 (#7925)。

    • 现在许多 jax.numpy 函数需要类似数组的输入,如果传递一个列表将会报错(#7747 #7802 #7907)。有关此更改背后原因的讨论,请参见 #7737

    • 在诸如 jax.jit 的转换中,jax.numpy.array 总是将其生成的数组编排到跟踪的计算中。以前,即使在 jax.jit 装饰器下,jax.numpy.array 有时也会生成一个设备上的数组。这一更改可能会破坏使用 JAX 数组执行形状或索引计算的代码,这些计算必须是静态已知的;解决方法是使用经典的 NumPy 数组来执行此类计算。

    • jnp.ndarray 现在是 JAX 数组的真正基类。特别是,这意味着对于标准的 numpy 数组 xisinstance(x, jnp.ndarray) 现在将返回 False (#7927)。

  • 新功能:

jax 0.2.20 (2021年9月2日)#

  • GitHub 提交

  • 重大变更

    • jnp.poly* 函数现在需要类似数组的输入 (#7732)

    • jnp.unique 和其他类似集合的操作现在需要类似数组的输入 (#7662)

jaxlib 0.1.71 (2021年9月1日)#

  • 重大变更:

    • 已不再支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。

jax 0.2.19 (2021年8月12日)#

  • GitHub 提交

  • 重大变更:

    • 根据 弃用政策,NumPy 1.17 的支持已被取消。请升级到受支持的 NumPy 版本。

    • jit 装饰器已添加到 JAX 数组上多个操作符的实现中。这加快了常见操作符(如 +)的调度时间。

      这一更改对大多数用户来说应该基本上是透明的。然而,有一个已知的行为变化,即当直接传递给JAX操作符时,大整数常量现在可能会产生错误(例如,x + 2**40)。解决方法是将常量转换为显式类型(例如,np.float64(2**40))。

  • 新功能:

    • 改进了 jax2tf 中形状多态性的支持,对于需要在数组计算中使用维度大小的操作,例如 jnp.mean。(#7317)

  • 错误修复:

    • 一些从上一个版本泄露的跟踪错误(#7613

jaxlib 0.1.70 (2021年8月9日)#

  • 重大变更:

    • 根据 弃用政策,Python 3.6 的支持已被取消。请升级到受支持的 Python 版本。

    • 根据 弃用政策,NumPy 1.17 的支持已被取消。请升级到受支持的 NumPy 版本。

    • host_callback 机制现在为每个本地设备使用一个线程来调用 Python 回调。以前所有设备共用一个线程。这意味着回调现在可能是交错调用的。对应于一个设备的回调仍将按顺序调用。

jax 0.2.18 (2021年7月21日)#

  • GitHub 提交

  • 重大变更:

    • 根据 弃用政策,Python 3.6 的支持已被取消。请升级到受支持的 Python 版本。

    • jaxlib 的最低版本现在是 0.1.69。

    • 已移除 jax.dlpack.from_dlpack()backend 参数。

  • 新功能:

  • 错误修复:

    • 加强了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的 axis 值一起使用,也不会与空的缩减维度一起使用。(#7196)

jaxlib 0.1.69 (2021年7月9日)#

  • 修复TFRT CPU后端中的错误,这些错误导致结果不正确。

jax 0.2.17 (2021年7月9日)#

  • GitHub 提交

  • 错误修复:

    • 对于 jaxlib <= 0.1.68,默认使用较旧的“stream_executor”CPU 运行时以解决 #7229 问题,该问题由于并发问题导致 CPU 输出错误。

  • 新功能:

jax 0.2.16 (2021年6月23日)#

jax 0.2.15 (2021年6月23日)#

  • GitHub 提交

  • 新功能:

    • #7042 在CPU上启用了TFRT CPU后端,显著提升了CPU的调度性能。

    • The jax2tf.convert() 支持布尔值的不等式和最小/最大值(#6956)。

    • 新 SciPy 函数 jax.scipy.special.lpmn_values()

  • 重大变更:

  • 错误修复:

    • 修复了阻止从JAX到TF再返回的往返问题:jax2tf.call_tf(jax2tf.convert) (#6947)。

jaxlib 0.1.68 (2021年6月23日)#

  • 错误修复:

    • 修复了TFRT CPU后端在将TPU缓冲区传输到CPU时产生NaN的错误。

jax 0.2.14 (2021年6月10日)#

  • GitHub 提交

  • 新功能:

    • 现在,jax2tf.convert() 支持 pjitsharded_jit

    • 一个新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。

    • 在足够新的 IPython 版本中,使用 __tracebackhide__ 的新回溯过滤模式现在默认启用。

    • 即使在使用算术运算中的未知维度时,jax2tf.convert() 也支持形状多态性,例如 jnp.reshape(-1) (#6827)。

    • The jax2tf.convert() 在 TF 操作中生成带有位置信息的自定义属性。jax2tf 之后生成的 XLA 代码与 JAX/XLA 具有相同的位置信息。

    • 新 SciPy 函数 jax.scipy.special.lpmn()

  • 错误修复:

    • 现在,jax2tf.convert() 确保它使用与 JAX 相同的类型规则来处理 Python 标量以及选择 32 位与 64 位计算(#6883)。

    • 现在,jax2tf.convert() 正确地将 enable_xla 转换参数的作用域限定在即时转换期间(#6720)。

    • 现在,jax2tf.convert() 使用 XlaDot TensorFlow 操作来转换 lax.dot_general,以更好地保持与 JAX 数值精度的一致性(#6717)。

    • 现在,jax2tf.convert() 支持复数的非等式比较和最小/最大值计算(#6892)。

jaxlib 0.1.67 (2021年5月17日)#

jaxlib 0.1.66 (2021年5月11日)#

  • 新功能:

    • CUDA 11.1 轮子现在支持所有 CUDA 11 版本 11.1 或更高版本。

      Nvidia 现在承诺从 CUDA 11.1 开始,CUDA 次要版本之间的兼容性。这意味着 JAX 可以发布一个与 CUDA 11.2 和 11.3 兼容的单一 CUDA 11.1 轮子。

      不再有针对 CUDA 11.2(或更高版本)的单独 jaxlib 发布;请为这些版本使用 CUDA 11.1 的轮子(cuda111)。

    • Jaxlib 现在在 CUDA 轮子中捆绑了 libdevice.10.bc。应该不需要指向 JAX 的 CUDA 安装来找到这个文件。

    • jit() 实现中添加了对静态关键字参数的自动支持。

    • 增加了对预转换异常跟踪的支持。

    • 初步支持从 jit() 转换的计算中修剪未使用的参数。修剪工作仍在进行中。

    • 改进了 PyTreeDef 对象的字符串表示。

    • 增加了对 XLA 的可变 ReduceWindow 的支持。

  • 错误修复:

    • 修复了在向计算传递大量参数时远程云TPU支持中的一个错误。

    • 修复了一个错误,该错误意味着由 jit() 转换的函数不会触发 JAX 垃圾回收。

jax 0.2.13 (2021年5月3日)#

  • GitHub 提交

  • 新功能:

    • 当与 jaxlib 0.1.66 结合使用时,jax.jit() 现在支持静态关键字参数。新增了 static_argnames 选项,用于指定关键字参数为静态。

    • jax.nonzero() 有一个新的可选 size 参数,允许它在 jit 中使用(#6501

    • jax.numpy.unique() 现在支持 axis 参数 (#6532)。

    • jax.experimental.host_callback.call() 现在支持 pjit.pjit (#6569)。

    • 添加了 jax.scipy.linalg.eigh_tridiagonal(),用于计算三对角矩阵的特征值。目前仅支持特征值计算。

    • 异常中过滤和未过滤的堆栈跟踪顺序已更改。从 JAX 转换代码抛出的异常所附带的回溯现在已被过滤,包含原始跟踪的 UnfilteredStackTrace 异常作为过滤异常的 __cause__。过滤后的堆栈跟踪现在也适用于 Python 3.6。

    • 如果由反向模式自动微分转换的代码抛出异常,JAX现在尝试将一个包含正向传递中创建原始操作的堆栈跟踪的 JaxStackTraceBeforeTransformation 对象附加为异常的 __cause__。需要 jaxlib 0.1.66。

  • 重大变更:

    • 以下函数名称已更改。仍然存在别名,因此这不应破坏现有代码,但别名最终将被移除,因此请更改您的代码。

    • 同样地,local_devices() 的参数已从 host_id 重命名为 process_index

    • 除了函数之外,传递给 jax.jit() 的参数现在被标记为仅关键字。这一更改是为了防止在向 jit 添加参数时发生意外破坏。

  • 错误修复:

    • 现在,jax2tf.convert() 在存在整数输入函数的梯度时也能正常工作(#6360)。

    • 修复了在使用捕获的 tf.Variablejax2tf.call_tf() 中的断言失败问题(#6572)。

jaxlib 0.1.65 (2021年4月7日)#

jax 0.2.12 (2021年4月1日)#

  • GitHub 提交

  • 新功能

  • 重大变更:

    • jaxlib 的最低版本现在是 0.1.64。

    • 一些性能分析器API的名称已经更改。仍然存在别名,因此这不会破坏现有代码,但别名最终将被移除,因此请更改您的代码。

    • Omnistaging 不能再被禁用。更多信息请参见 omnistaging

    • 大于最大 int64 值的 Python 整数现在将在所有情况下导致溢出,而不是在某些情况下被静默转换为 uint64 (#6047)。

    • 在非 X64 模式下,超出 int32 表示范围的 Python 整数现在将导致 OverflowError,而不是静默地截断其值。

  • 错误修复:

    • host_callback 现在支持在参数和结果中使用空数组 (#6262)。

    • jax.random.randint() 会裁剪而不是包装越界限制,并且现在可以在指定数据类型的完整范围内生成整数(#5868

jax 0.2.11 (2021年3月23日)#

  • GitHub 提交

  • 新功能:

    • #6112 添加了上下文管理器:jax.enable_checksjax.check_tracer_leaksjax.debug_nansjax.debug_infsjax.log_compiles

    • #6085 添加了 jnp.delete

  • 错误修复:

    • #6136jax.flatten_util.ravel_pytree 扩展为处理整数数据类型。

    • #6129 修复了处理某些常量(如 enum.IntEnums)的错误

    • #6145 修复了不完整beta函数中的批处理问题

    • #6014 修复了跟踪期间的H2D传输

    • #6165 在将某些大Python整数转换为浮点数时避免了OverflowErrors

  • 重大变更:

    • 最低的 jaxlib 版本现在是 0.1.62。

jaxlib 0.1.64 (2021年3月18日)#

jaxlib 0.1.63 (2021年3月17日)#

jax 0.2.10 (2021年3月5日)#

  • GitHub 提交

  • 新功能:

    • jax.scipy.stats.chi2() 现在作为一个分布可用,具有 logpdf 和 pdf 方法。

    • jax.scipy.stats.betabinom() 现在作为一个分布可用,具有 logpmf 和 pmf 方法。

    • 添加了 jax.experimental.jax2tf.call_tf() 以从 JAX 调用 TensorFlow 函数(#5627)和 README)。

    • 扩展了 lax.pad 的批处理规则,以支持填充值的批处理。

  • 错误修复:

  • 重大变更:

    • JAX 的提升规则进行了调整,以使提升更加一致且不受 JIT 影响。特别是,二元操作现在在适当的情况下可以产生弱类型值。这一变化的主要用户可见效果是,某些操作的结果精度与之前不同;例如表达式 jnp.bfloat16(1) + 0.1 * jnp.arange(10) 之前返回一个 float64 数组,现在返回一个 bfloat16 数组。JAX 的类型提升行为在 类型提升 中描述。

    • jax.numpy.linspace() 现在计算整数值的向下取整,即向 -inf 而不是 0 取整。这一更改是为了与 NumPy 1.20.0 匹配。

    • jax.numpy.i0() 不再接受复数。以前,该函数计算复数参数的绝对值。此更改是为了匹配 NumPy 1.20.0 的语义。

    • 几个 jax.numpy 函数不再接受元组或列表作为数组参数:jax.numpy.pad()、:funcjax.numpy.raveljax.numpy.repeat()jax.numpy.reshape()。通常,jax.numpy 函数应与标量或数组参数一起使用。

jaxlib 0.1.62 (2021年3月9日)#

  • 新功能:

    • jaxlib 轮子现在默认构建为在 x86-64 机器上需要 AVX 指令。如果你想在不支持 AVX 的机器上使用 JAX,你可以使用 --target_cpu_features 标志通过 build.py 从源代码构建 jaxlib。--target_cpu_features 也替代了 --enable_march_native

jaxlib 0.1.61 (2021年2月12日)#

jaxlib 0.1.60 (2021年2月3日)#

  • 错误修复:

    • 修复了将CPU DeviceArrays转换为NumPy数组时的内存泄漏问题。该内存泄漏问题存在于jaxlib版本0.1.58和0.1.59中。

    • bool, int8, 和 uint8 现在被认为是安全的,可以转换为 bfloat16 NumPy 扩展类型。

jax 0.2.9 (2021年1月26日)#

jaxlib 0.1.59 (2021年1月15日)#

jax 0.2.8 (2021年1月12日)#

  • GitHub 提交

  • 新功能:

  • 错误修复:

    • jax.numpy.arccosh 现在对于复数输入返回与 numpy.arccosh 相同的分支 (#5156)

    • host_callback.id_tap 现在也适用于 jax.pmapid_tapid_print 有一个可选参数,用于请求将从中获取值的设备作为关键字参数传递给 tap 函数(#5182)。

  • 重大变更:

  • 新功能:

    • 用于调试 inf 的新标志,类似于用于 NaN 的标志(#5224)。

jax 0.2.7 (2020年12月4日)#

  • GitHub 提交

  • 新功能:

    • 添加 jax.device_put_replicated

    • jax.experimental.sharded_jit 添加多主机支持

    • 添加对 jax.numpy.linalg.eig 计算的特征值进行区分支持

    • 添加对在Windows平台上构建的支持

    • jax.pmap 中添加对一般 in_axes 和 out_axes 的支持

    • jax.numpy.linalg.slogdet 添加复杂支持

  • 错误修复:

    • 修复 jax.numpy.sinc 在零点的二阶以上高阶导数

    • 修复了转置规则中关于符号零的一些难以触及的错误

  • 重大变更:

    • jax.experimental.optix 已被删除,取而代之的是独立的 optax Python 包。

    • 使用非元组序列索引 JAX 数组现在会引发 TypeError。这种索引方式自 Numpy v1.16 和 JAX v0.2.4 以来已被弃用。请参见 #4564

jax 0.2.6 (2020年11月18日)#

  • GitHub 提交

  • 新功能:

    • 为 jax.experimental.jax2tf 转换器添加对形状多态跟踪的支持。参见 README.md

  • 重大变更清理

    • 在 jax.jit 和 xla_computation 中对不可哈希的静态参数引发错误。参见 cb48f42

    • 改进类型提升行为的一致性(#4744):

      • 将一个复杂的Python标量添加到JAX浮点数时,会尊重JAX浮点的精度。例如,jnp.float32(1) + 1j 现在返回 complex64,而之前返回的是 complex128

      • 包含 uint64、有符号整数和第三种类型的三个或更多项的类型提升结果现在与参数顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) 都返回 float16,而之前第一个返回 float64,第二个返回 float16

    • jax.lax_linalg 线性代数模块的内容(未记录)现在公开为 jax.lax.linalg

    • jax.random.PRNGKey 现在在 JIT 编译内外产生相同的结果(#4877)。这需要在少数特定情况下更改给定种子的结果:

      • 使用 jax_enable_x64=False,作为 Python 整数传递的负种子在 JIT 模式外现在返回不同的结果。例如,jax.random.PRNGKey(-1) 之前返回 [4294967295, 4294967295],现在返回 [0, 4294967295]。这与 JIT 中的行为相匹配。

      • 超出 int64 表示范围的种子现在在非JIT环境中会导致 OverflowError 而不是 TypeError。这与JIT中的行为一致。

      要在 JIT 外部恢复之前为负整数返回的键,并且 jax_enable_x64=False,您可以使用:

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • 当试图访问已被删除的值时,DeviceArray 现在会引发 RuntimeError 而不是 ValueError

jaxlib 0.1.58 (2021年1月12日左右)#

  • 修复了一个错误,该错误导致 JAX 有时返回特定于平台的类型(例如,np.cint)而不是标准类型(例如,np.int32)。(#4903)

  • 修复了在常量折叠某些 int16 操作时发生的崩溃。(#4971)

  • pytree.flatten() 添加了 is_leaf 谓词。

jaxlib 0.1.57 (2020年11月12日)#

  • 修复了GPU轮子中的许多linux2010合规性问题。

  • 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。

  • 修复了一个错误,该错误导致 bfloat16 值的哈希未正确初始化,可能会发生变化 (#4651)。

  • 添加在将数组传递给 DLPack 时保留所有权的支持 (#4636)。

  • 修复了批量三角求解中大小大于128但不是128倍数的错误。

  • 修复了在多个GPU上执行并发FFT时的错误 (#3518)。

  • 修复了分析器中工具缺失的错误 (#4427)。

  • 已放弃对 CUDA 10.0 的支持。

jax 0.2.5 (2020年10月27日)#

jax 0.2.4 (2020年10月19日)#

  • GitHub 提交

  • 改进:

    • jax.experimental.host_callback 添加对 remat 的支持。参见 #4608

  • 弃用

    • 使用非元组序列进行索引现在已被弃用,这与Numpy中的类似弃用一致。在未来的版本中,这将导致TypeError。请参阅 #4564

jaxlib 0.1.56 (2020年10月14日)#

jax 0.2.3 (2020年10月14日)#

  • GitHub 提交

  • 这么快发布另一个版本的原因是,我们需要暂时回滚一个新的 jit 快速路径,同时我们正在调查性能下降的问题。

jax 0.2.2 (2020年10月13日)#

jax 0.2.1 (2020年10月6日)#

jax (0.2.0) (2020年9月23日)#

jax (0.1.77) (2020年9月15日)#

jaxlib 0.1.55 (2020年9月8日)#

  • 更新 XLA:

    • 修复 DLPackManagedTensorToBuffer 中的错误 (#4196)

jax 0.1.76 (2020年9月8日)#

jax 0.1.75 (2020年7月30日)#

  • GitHub 提交

  • Bug 修复:

    • 使 jnp.abs() 适用于无符号输入 (#3914)

  • 改进:

    • 在标志后添加了“Omnistaging”行为,默认情况下禁用 (#3370)

jax 0.1.74 (2020年7月29日)#

  • GitHub 提交

  • 新功能:

    • BFGS (#3101)

    • TPU 支持半精度算术 (#3878)

  • Bug 修复:

    • 防止一些意外的 dtype 警告 (#3874)

    • 修复自定义导数中的多线程错误(#3845,#3869)

  • 改进:

    • 更快的 searchsorted 实现 (#3873)

    • 为 jax.numpy 排序算法提供更好的测试覆盖率 (#3836)

jaxlib 0.1.52 (2020年7月22日)#

  • 更新 XLA。

jax 0.1.73 (2020年7月22日)#

  • GitHub 提交

  • jaxlib 的最低版本现在是 0.1.51。

  • 新功能:

    • jax.image.resize. (#3703)

    • hfft 和 ihfft (#3664)

    • jax.numpy.intersect1d (#3726)

    • jax.numpy.lexsort (#3812)

    • lax.scanscan 原语在降低到 XLA 时支持一个用于循环展开的 unroll 参数(#3738)。

  • Bug 修复:

    • 修复重复轴的缩减错误 (#3618)

    • 修复了lax.pad 对于大小为0的输入维度的形状规则。(#3608)

    • make psum 转置处理零余切向量 (#3653)

    • 修复在大小为0的轴上进行reduce-prod的JVP时出现的形状错误。(#3729)

    • 通过 jax.lax.all_to_all 支持差异化 (#3733)

    • 在 jax.scipy.special.zeta 中解决 nan 问题 (#3777)

  • 改进:

    • jax2tf 的许多改进

    • 使用单次遍历的可变归约重新实现 argmin/argmax。(#3611)

    • 默认启用XLA SPMD分区。(#3151)

    • 添加对0维转置卷积的支持(#3643)

    • 使 LU 梯度适用于低秩矩阵 (#3610)

    • 在 jet 中支持 multiple_results 和自定义 JVPs (#3657)

    • 将 reduce-window 填充泛化为支持 (lo, hi) 对。(#3728)

    • 在CPU和GPU上实现复杂的卷积。(#3735)

    • 使 jnp.take 对空数组的空切片起作用。(#3751)

    • 放宽 dot_general 的维度排序规则。(#3778)

    • 为GPU启用缓冲区捐赠。(#3800)

    • 添加对基础扩张和窗口扩张的支持以减少窗口操作… (#3803)

jaxlib 0.1.51 (2020年7月2日)#

  • 更新 XLA。

  • 为 host_callback 添加新的运行时支持。

jax 0.1.72 (2020年6月28日)#

  • GitHub 提交

  • 错误修复:

    • 修复了在上一个版本中引入的 odeint 错误,参见 #3587

jax 0.1.71 (2020年6月25日)#

  • GitHub 提交

  • 最小 jaxlib 版本现在是 0.1.48。

  • 错误修复:

    • 允许 jax.experimental.ode.odeint 动力学函数闭包化相对于我们正在微分的值 #3562

jaxlib 0.1.50 (2020年6月25日)#

  • 添加对 CUDA 11.0 的支持。

  • 放弃对 CUDA 9.2 的支持(我们仅维护对最近四个 CUDA 版本的支持。)

  • 更新 XLA。

jaxlib 0.1.49 (2020年6月19日)#

jaxlib 0.1.48 (2020年6月12日)#

  • 新功能:

    • 添加对快速回溯收集的支持。

    • 增加了对设备上堆分析的初步支持。

    • bfloat16 类型实现了 np.nextafter

    • CPU 和 GPU 上的 FFT 支持 Complex128 类型。

  • 错误修复:

    • 提高了GPU上float64 tanh 的精度。

    • float64 在 GPU 上的散布速度要快得多。

    • 在CPU上进行复杂矩阵乘法应该会快得多。

    • 在CPU上的稳定排序现在应该是真正稳定的。

    • CPU 后端中的并发错误修复。

jax 0.1.70 (2020年6月8日)#

  • GitHub 提交

  • 新功能:

    • lax.switch 引入了带有多个分支的索引条件,以及对 cond 原语的泛化 #3318

jax 0.1.69 (2020年6月3日)#

jax 0.1.68 (2020年5月21日)#

jax 0.1.67 (2020年5月12日)#

  • GitHub 提交

  • 新功能:

    • 支持使用 axis_index_groups #2382 对 pmapped 轴的子集进行缩减。

    • 实验性支持从编译代码中打印和调用主机端Python函数。参见 id_print 和 id_tap (#3006)。

  • 显著变化:

    • jax.numpy 导出的名称的可见性已经收紧。这可能会破坏之前意外使用这些名称的代码。

jaxlib 0.1.47 (2020年5月8日)#

  • 修复了出料口的崩溃问题。

jax 0.1.66 (2020年5月5日)#

jaxlib 0.1.46 (2020年5月5日)#

  • 修复了在 Mac OS X 上使用线性代数函数时的崩溃问题 (#432)。

  • 修复了由于在操作系统或虚拟机管理程序禁用AVX512指令时使用AVX512指令导致的非法指令崩溃问题 (#2906)。

jax 0.1.65 (2020年4月30日)#

  • GitHub 提交

  • 新功能:

    • 奇异矩阵行列式的微分 #2809

  • 错误修复:

    • 修复 odeint() 对具有时间依赖动力学的 ODE 的时间微分 #2817,同时添加 ODE CI 测试。

    • 修复 lax_linalg.qr() 的微分问题 #2867

jaxlib 0.1.45 (2020年4月21日)#

  • 修复段错误: #2755

  • 在排序HLO时,通过Python实现对is_stable选项的全面支持。

jax 0.1.64 (2020年4月21日)#

jaxlib 0.1.44 (2020年4月16日)#

  • 修复了一个错误,即如果存在多个不同型号的GPU,JAX只会编译适合第一个GPU的程序。

  • 修复了 batch_group_count 卷积的错误。

  • 增加了预编译的 SASS 以支持更多 GPU 版本,以避免启动时的 PTX 编译挂起。

jax 0.1.63 (2020年4月12日)#

  • GitHub 提交

  • 添加了来自 #2026jax.custom_jvpjax.custom_vjp,请参阅 教程笔记本。已弃用 jax.custom_transforms 并从文档中移除(尽管它仍然有效)。

  • 添加 scipy.sparse.linalg.cg #2566

  • 更改了 Tracers 的打印方式,以显示更多对调试有用的信息 #2591

  • 使 jax.numpy.isclose 正确处理 naninf #2501

  • jax.experimental.jet 添加了几条新规则 #2537

  • 修复了当未提供 scale/centerjax.experimental.stax.BatchNorm 的问题。

  • 修复了 jax.numpy.einsum 中广播的一些缺失情况 #2512

  • 在并行前缀扫描 #2596 的基础上实现 jax.numpy.cumsumjax.numpy.cumprod,并使 reduce_prod 可微分至任意阶 #2597

  • batch_group_count 添加到 conv_general_dilated #2635

  • test_util.check_grads 添加文档字符串 #2656

  • 添加 callback_transform #2665

  • 实现 rollaxisconvolve/correlate 一维和二维、copysigntruncroots 以及 quantile/percentile 插值选项。

jaxlib 0.1.43 (2020年3月31日)#

  • 修复了GPU上Resnet-50的性能退化问题。

jax 0.1.62 (2020年3月21日)#

  • GitHub 提交

  • JAX 已停止对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。

  • 移除了内部函数 lax._safe_mul,该函数实现了约定 0. * nan == 0.。这一更改意味着在某些程序进行微分时,当它们之前产生正确值时,现在会产生 nans,尽管它确保了对于其他程序,产生的是 nans 而不是静默的错误结果。详情请参见 #2447 和 #1052。

  • 添加了一个 all_gather 并行便利函数。

  • 核心代码中更多的类型注解。

jaxlib 0.1.42 (2020年3月19日)#

  • jaxlib 0.1.41 因API不兼容导致云TPU支持中断。此版本再次修复了该问题。

  • JAX 已停止对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。

jax 0.1.61 (2020年3月17日)#

  • GitHub 提交

  • 修复了 Python 3.5 的支持。这将是最后一个支持 Python 3.5 的 JAX 或 jaxlib 版本。

jax 0.1.60 (2020年3月17日)#

  • GitHub 提交

  • 新功能:

    • jax.pmap() 有一个 static_broadcast_argnums 参数,允许用户指定应作为编译时常量处理并广播到所有设备的参数。它的工作原理类似于 jax.jit() 中的 static_argnums

    • 改进了当跟踪器错误地保存在全局状态时的错误信息。

    • 添加了 jax.nn.one_hot() 实用函数。

    • 添加了 jax.experimental.jet 以实现指数级更快的更高阶自动微分。

    • jax.lax.broadcast_in_dim() 的参数增加了更多的正确性检查。

  • 最低的 jaxlib 版本现在是 0.1.41。

jaxlib 0.1.40 (2020年3月4日)#

  • 在 Jaxlib 中添加了对 TensorFlow 分析器的实验性支持,这允许从 TensorBoard 追踪 CPU 和 GPU 计算。

  • 包含通过 NCCL 进行通信的多主机 GPU 计算的原型支持。

  • 提升GPU上NCCL集合操作的性能。

  • 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 的实现。

  • 支持在XLA编译时已知的设备分配。

jax 0.1.59 (2020年2月11日)#

  • GitHub 提交

  • 重大变更

    • jaxlib 的最低版本现在是 0.1.38。

    • 通过移除 Jaxpr.freevarsJaxpr.bound_subjaxprs 来简化 Jaxpr。调用原语(xla_callxla_pmapsharded_callremat_call)获得了一个新的参数 call_jaxpr,它包含一个完全封闭(无 constvars)的 jaxpr。此外,为原语添加了一个新的字段 call_primitive

  • 新功能:

    • 反向模式自动微分(例如 grad)对 lax.cond 的支持,使其现在在两种模式下都可微分(#2091

    • JAX 现在支持 DLPack,这使得可以以零拷贝的方式与其他库(如 PyTorch)共享 CPU 和 GPU 数组。

    • JAX GPU DeviceArrays 现在支持 __cuda_array_interface__,这是另一个用于与其他库(如 CuPy 和 Numba)共享 GPU 数组的零拷贝协议。

    • JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间的零拷贝缓冲区共享。

    • 添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知的慢速测试。

jaxlib 0.1.39 (2020年2月11日)#

  • 更新 XLA。

jaxlib 0.1.38 (2020年1月29日)#

  • CUDA 9.0 不再受支持。

  • CUDA 10.2 轮子现在默认构建。

jax 0.1.58 (2020年1月28日)#

显著的错误修复#

  • 随着Python 3的升级,JAX不再依赖于fastcache,这将有助于安装。