调查回归问题#

所以你更新了JAX并且遇到了速度下降的问题?你有一些时间并且准备好调查这个问题了吗?让我们先创建一个JAX问题。但如果你能确定触发这个问题的提交,那将对我们非常有帮助。

本文档解释了我们如何识别导致 15% 性能下降 的提交。

步骤#

如果复现器足够快,这可以很容易地完成。这是一种暴力方法,而不是二分法,但如果复现器足够快,它效果很好。这确保你总是测试兼容的 XLA 和 JAX 提交。它还限制了 XLA 的重新编译。

以下是一个建议的调查策略:

  1. 你可以在两个版本之间对夜间容器进行暴力测试。

  2. 在保持 XLA 和 JAX 同步的同时进行每小时重新编译。

  3. 最终验证:可能需要手动检查几个提交(或使用 git bisect)。

夜间调查#

这可以通过使用 NVIDIA JAX-Toolbox 夜间容器 来完成。

  • 有些日子,错误会阻止容器被构建,或者存在临时的回归问题。只需丢弃那些日子。

  • 因此,你应该最终确定一个特定的日子或几天,在这些日子里回归发生。

  • 要自动化这个过程,你需要两个 Python 脚本:

    • test_runner.sh: 将启动容器并开始测试。

    • test.sh: 将安装缺失的依赖项并运行测试

以下是用于该问题的实际示例脚本:https://github.com/google/jax/issues/17686

  • test_runner.sh:

  for m in 7 8 9; do
    for d in `seq -w 1 30`; do
      docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
    done
  Done
  • test.sh:

  pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
  git clone https://github.com/Autodesk/XLB
  cd XLB
  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

然后你可以对每个输出进行 grep 操作,以查看回归发生的时间:grep MLUPS OUT*。以下是我们得到的结果:

OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25:  File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634

这发现8-24是好的,但8-26是坏的。在8-25有另一个问题阻止了结果的获取。所以我们需要在8-24和8-26之间每小时进行调查。之前有一个较小的减速,让我们在这个例子中忽略它。它将只是在那些日期之间的另一个每小时调查。

每小时调查#

这会在两个日期之间的每小时检查一次 JAX 和 XLA,重新构建所有内容并运行测试。脚本的结构不同。我们启动工作容器并保持它。然后在其中,我们只触发增量 XLA 构建,除了第一次构建。因此,第一次迭代后速度会快得多。

  • test_runner2.sh:

  # Execute this script inside the container:
  # docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
  cd /opt/xla-source
  git remote update
  cd /opt/jax-source
  git remote update
  pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
  cd /tmp
  git clone https://github.com/Autodesk/XLB
  cd XLB

  for d in `seq -w 24 26`; do
      for h in `seq -w 0 24`; do
          echo $m $d $h
          /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
      done
  done
  • test2.sh:

  echo "param: $@"
  cd /opt/xla-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q
  cd /opt/jax-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q

  rm /opt/jax-source/dist/jax*.whl
  build-jax.sh # The script is in the nightly container

  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

现在,你可以在新的输出文件上执行 grep 命令,以查看问题出现在哪些时间段。

最终验证#

通过这个,你需要检查在那几个小时内的 JAX 和 XLA 历史。可能有一些提交需要测试。如果你想更高级,可以使用 git bisect。

这可以改进吗?#

是的!如果这是一个崩溃回归,能够进行二分查找将会很有用。但这会更复杂。如果有人想贡献这样的指令,请提交一个PR ;)

对于速度回归,二分法可能会隐藏一些信息。我们不会那么容易地看到这里有两个回归。