我们对 Keras 3 的三个后端(TensorFlow、JAX、PyTorch)进行了基准测试,并与使用 TensorFlow 的 Keras 2 进行了比较。有关重现我们结果的代码和设置详细信息,请查阅 这里。
我们选择了一组流行的计算机视觉和自然语言处理模型,用于生成和非生成 AI 任务。请参见下表以查看我们的选择。
表 1:基准测试中使用的模型。
非生成性 | 生成性 | |
---|---|---|
CV | SegmentAnything1 | StableDiffusion2 |
NLP | BERT3 | Gemma4, Mistral5 |
我们并不是在测量每个框架所能达到的最佳性能,而是常见用户工作流的一键可用性能。考虑到这一目标,我们利用了 KerasCV 和 KerasNLP 中现有的实现,采用了 Keras 版本的模型。
所有基准测试均在 Google Cloud Compute Engine 的 a2-highgpu-1g
机器类型上使用一台具有 40GB GPU 内存的 NVIDIA A100 GPU 进行,配备 12 个 vCPU 和 85GB 主机内存。
表 2 显示了每步的基准测试结果,以毫秒为单位。每一步涉及对单个数据批次的训练或预测。结果取自 100 步的平均值,排除第一步,因为第一步包含模型创建和编译的开销。
为了公平比较,如果模型和任务(拟合或预测)相同,我们在框架之间使用相同的批处理大小。然而,对于不同模型和任务,由于它们的大小和架构不同,我们使用不同的批处理大小以避免内存不足(过大)或 GPU 利用率不足(过小)。
对于大型语言模型(Gemma 和 Mistral),我们也使用相同的批处理大小,因为它们属于相同的模型类型,参数数量相似(7B)。我们还对文本生成进行了基准测试,批处理大小为 1,因为这是用户广泛要求的。我们在训练和推理中使用了 bfloat16
精度,并使用 LoRA6 进行训练(微调)。
为了测量一键可用性能,我们尽量使用所有默认设置。例如,使用高级 API(例如使用 Keras model.fit()
)进行尽可能少的配置。
请注意,这与测量特定硬件/框架/模型组合的优化实施是相当不同的。有关不同框架的最佳优化结果,请参阅 MLPerf。
表 2:基准测试结果。速度以 ms/步为单位测量。数字越小越好。
批处理 大小 |
Keras 2 (TensorFlow) |
Keras 3 (TensorFlow) |
Keras 3 (JAX) |
Keras 3 (PyTorch) (急切模式) |
Keras 3 (最佳) |
|
---|---|---|---|---|---|---|
SegmentAnything (拟合) |
1 | 386.93 | 355.25 | 361.69 | 1,388.87 | 355.25 |
SegmentAnything (预测) |
4 | 1,859.27 | 438.50 | 376.34 | 1,720.96 | 376.34 |
稳定扩散 (拟合) |
8 | 1,023.21 | 392.24 | 391.21 | 823.44 | 391.21 |
稳定扩散 (预测) |
13 | 649.71 | 616.04 | 627.27 | 1,337.17 | 616.04 |
BERT (拟合) |
32 | 486.00 | 214.49 | 222.37 | 808.68 | 214.49 |
BERT (预测) |
256 | 470.12 | 466.01 | 418.72 | 1,865.98 | 418.72 |
Gemma (拟合) |
8 | NA | 232.52 | 273.67 | 525.15 | 232.52 |
Gemma (生成) |
32 | NA | 1,134.91 | 1,128.21 | 7,952.67* | 1,128.21 |
Gemma (生成) |
1 | NA | 758.57 | 703.46 | 7,649.40* | 703.46 |
Mistral (拟合) |
8 | NA | 185.92 | 213.22 | 452.12 | 185.92 |
Mistral (生成) |
32 | NA | 966.06 | 957.25 | 10,932.59* | 957.25 |
Mistral (生成) |
1 | NA | 743.28 | 679.30 | 11,054.67* | 679.30 |
* 此时使用 PyTorch 后端进行 LLM 推理的速度异常缓慢,因为 KerasNLP 使用静态序列填充,而不是 HuggingFace。这个问题将在不久的将来得到解决。
Keras 的三个后端各具独特优势。至关重要的是,从性能的角度来看,不存在一个后端能够始终超越其他后端。最快的后端通常取决于您特定的模型架构。
这突显了在追求最佳性能时框架可选性的价值。Keras 3 使您能够无缝切换后端,确保您找到与模型理想匹配的后端。
我们还计算了 Keras 3(使用其最佳性能的后端)相对于 Keras 2 的吞吐量(步数/ms)增加情况,结果如图所示。
图1:Keras 3 相对于 Keras 2 的速度提升,以吞吐量(步数/毫秒)为衡量标准
Keras 3 在所有基准模型中始终优于 Keras 2,在许多情况下都有显着的速度提升。SegmentAnything 推断获得了 380% 的显着提升,StableDiffusion 训练吞吐量增加了超过 150%,而 BERT 训练吞吐量也上升了超过 100%。
重要的是,即使您只是升级到 Keras 3 并继续使用 TensorFlow 后端,您仍然会看到性能提升。这主要是因为 Keras 2 直接使用了更多的 TensorFlow 融合操作,这在某些用例中可能并不理想于 XLA 编译。
框架性能在很大程度上依赖于特定的模型。Keras 3 使您能够为您的任务选择最快的框架——这个选择几乎总是能超越 Keras 2。
1 Kirillov, Alexander, et al. "Segment anything." ICCV (2023).
2 Rombach, Robin, et al. "高分辨率图像合成与潜在扩散模型。" CVPR (2022).
3 Kenton, Jacob, et al. "BERT: 语言理解的深度双向 Transformer 预训练。" NAACL (2019).
4 Banks, Jeanine, et al. "Gemma: 引入新的最先进开放模型。" The Keyword, Google (2024).
5 Jiang, Albert Q., et al. "Mistral 7B." arXiv 预印本 arXiv:2310.06825 (2023).
6 Hu, Edward J., et al. "Lora: 大型语言模型的低秩适应。" ICLR (2022).