开发者指南 / 在 JAX 中从头编写训练循环

在 JAX 中从头编写训练循环

作者: fchollet
创建日期: 2023/06/25
最后修改: 2023/06/25
描述: 在 JAX 中编写低级训练和评估循环。

在 Colab 中查看 GitHub 源码


设置

import os

# 本指南只能在 jax 后端运行。
os.environ["KERAS_BACKEND"] = "jax"

import jax

# 我们导入 TF 以便可以使用 tf.data。
import tensorflow as tf
import keras
import numpy as np

介绍

Keras 提供了默认的训练和评估循环,fit()evaluate()。 它们的用法在指南 使用内置方法进行训练和评估 中有所介绍。

如果你想自定义模型的学习算法,同时仍然利用 fit() 的便利性 (例如,使用 fit() 训练 GAN),你可以子类化 Model 类并 实现你自己的 train_step() 方法, 该方法在 fit() 期间被反复调用。

现在,如果你想要非常低级的控制训练和评估,你应该从头编写自己的训练和评估循环。这就是本指南的内容。


第一个端到端示例

要编写自定义训练循环,我们需要以下成分:

  • 当然是要训练的模型。
  • 一个优化器。你可以使用 keras.optimizers 中的优化器,或者 来自 optax 包的优化器。
  • 一个损失函数。
  • 一个数据集。在 JAX 生态系统中,标准的做法是通过 tf.data 加载数据, 所以我们也将使用它。

让我们把它们列出来。

首先,让我们获取模型和 MNIST 数据集:

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# 准备训练数据集。
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# 保留 10,000 个样本用于验证。
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# 准备训练数据集。
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# 准备验证数据集。
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

接下来,这里是损失函数和优化器。 在这种情况下,我们将使用 Keras 优化器。

# 实例化一个损失函数。
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# 实例化一个优化器。
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

在 JAX 中获取梯度

让我们使用自定义训练循环和迷你批量梯度来训练我们的模型。

在 JAX 中,梯度是通过元编程计算的:你调用 jax.grad(或 jax.value_and_grad)对一个函数进行调用,以创建该函数的梯度计算函数。

所以我们首先需要的是一个返回损失值的函数。这就是我们将用来生成梯度函数的函数。类似于这样:

def compute_loss(x, y):
    ...
    return loss

一旦你有了这样的函数,你可以通过元编程计算梯度,如下所示:

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

通常,你不仅想获取梯度值,还想获取损失值。你可以使用 jax.value_and_grad 而不是 jax.grad 来实现这一点:

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)

JAX 计算是纯无状态的

在 JAX 中,一切必须是纯无状态函数——因此我们的损失计算函数也必须是无状态的。这意味着所有 Keras 变量(例如权重张量)必须作为函数输入传递,并且在正向传递期间更新的任何变量必须作为函数输出返回。该函数不能有副作用。

在正向传递期间,Keras 模型的非训练变量可能会被更新。这些变量可以是,例如,RNG 种子状态变量或 BatchNormalization 统计数据。我们需要返回这些变量。所以我们需要类似这样的东西:

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

一旦你有了这样的函数,你可以通过在 value_and_grad 中指定 has_aux 来获取梯度函数:它告诉 JAX 损失计算函数返回的不仅仅是损失。请注意,损失应该始终是第一个输出。

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)

现在我们已经建立了基础,让我们实现这个 compute_loss_and_updates 函数。Keras 模型具有一个 stateless_call 方法,这在这里会派上用场。它的工作方式与 model.__call__ 一样,但需要你显式传递模型中所有变量的值,并且它不仅返回 __call__ 的输出,还返回(可能已更新的)非可训练变量。

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

让我们获取梯度函数:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

训练步骤函数

接下来,让我们实现端到端训练步骤的函数,该函数将同时运行前向传递,计算损失,计算梯度,还使用优化器更新可训练变量。这个函数也需要是无状态的,因此它将作为输入一个 state 元组,其中包含我们将要使用的每个状态元素:

  • trainable_variablesnon_trainable_variables:模型的变量。
  • optimizer_variables:优化器的状态变量,例如动量累加器。

要更新可训练变量,我们使用优化器的无状态方法 stateless_apply。这相当于 optimizer.apply(),但它总是需要传递 trainable_variablesoptimizer_variables。它返回更新后的可训练变量和更新后的优化器变量。

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        grads, trainable_variables, optimizer_variables
    )
    # 返回更新后的状态
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

使用 jax.jit 加快速度

默认情况下,JAX 操作是立即执行的,就像在 TensorFlow eager 模式和 PyTorch eager 模式中一样。并且就像 TensorFlow eager 模式和 PyTorch eager 模式一样,它非常慢——即发模式更适合作为调试环境,而不是实际工作的方式。让我们通过编译我们的 train_step 来加快速度。

当你有一个无状态的 JAX 函数时,你可以通过 @jax.jit 装饰器将其编译为 XLA。它将在第一次执行时被追踪,在随后的执行中,你将执行已追踪的图(这与 @tf.function(jit_compile=True) 非常相似)。让我们尝试一下:

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # 返回更新后的状态
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

我们现在准备好训练我们的模型。训练循环本身非常简单:我们只需重复调用 loss, state = train_step(state, data)

注意:

  • 我们在将由 tf.data.Dataset 生成的 TF 张量传递给我们的 JAX 函数之前,将其转换为 NumPy。
  • 所有变量必须提前构建:模型必须构建,优化器必须构建。由于我们使用的是功能性 API 模型,它已经构建,但如果它是一个子类模型,你需要在一批数据上调用它以构建它。

构建优化器变量。

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
训练损失(每批次)在步骤 0: 156.4785
到目前为止: 32 个样本
训练损失(每批次)在步骤 100: 2.5526
到目前为止: 3232 个样本
训练损失(每批次)在步骤 200: 1.8922
到目前为止: 6432 个样本
训练损失(每批次)在步骤 300: 1.2381
到目前为止: 9632 个样本
训练损失(每批次)在步骤 400: 0.4812
到目前为止: 12832 个样本
训练损失(每批次)在步骤 500: 2.3339
到目前为止: 16032 个样本
训练损失(每批次)在步骤 600: 0.5615
到目前为止: 19232 个样本
训练损失(每批次)在步骤 700: 0.6471
到目前为止: 22432 个样本
训练损失(每批次)在步骤 800: 1.6272
到目前为止: 25632 个样本
训练损失(每批次)在步骤 900: 0.9416
到目前为止: 28832 个样本
训练损失(每批次)在步骤 1000: 0.8152
到目前为止: 32032 个样本
训练损失(每批次)在步骤 1100: 0.8838
到目前为止: 35232 个样本
训练损失(每批次)在步骤 1200: 0.1278
到目前为止: 38432 个样本
训练损失(每批次)在步骤 1300: 1.9234
到目前为止: 41632 个样本
训练损失(每批次)在步骤 1400: 0.3413
到目前为止: 44832 个样本
训练损失(每批次)在步骤 1500: 0.2429
到目前为止: 48032 个样本

这里一个关键点是循环完全是无状态的 - 附加到模型的变量 (model.weights) 在循环中从未被更新。它们的新值仅存储在 state 元组中。这意味着在某个时刻,在保存模型之前,您应该将新的变量值附加回模型。

只需对每个要更新的模型变量调用 variable.assign(new_value)

可训练变量非可训练变量优化器变量 = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

底层指标处理

让我们在这个基本训练循环中添加指标监控。

您可以在这样从头编写的训练循环中轻松重用内置 Keras 指标(或您编写的自定义指标)。流程如下:

  • 在循环开始时实例化指标
  • metric_variables 包含在 train_step 的参数中 和 compute_loss_and_updates 的参数中。
  • compute_loss_and_updates 函数中调用 metric.stateless_update_state()。 它等同于 update_state() - 只是无状态的。
  • 当您需要显示指标的当前值时,在 train_step 之外 (在急切范围内),将新的指标变量值附加到指标对象 并调用 metric.result()
  • 当您需要清除指标的状态时调用 metric.reset_state() (通常在一个 epoch 结束时)

让我们利用这些知识在训练结束时计算训练和验证数据上的 CategoricalAccuracy

# 获取一个新模型
model = get_model()

# 实例化一个优化器以训练模型。
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# 实例化损失函数。
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# 准备指标。
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # 返回更新后的状态
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

我们还会准备一个评估步骤函数:

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

这里是我们的循环:

# 构建优化器变量。
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# 训练循环
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # 每100个批次记录一次日志。
    if step % 100 == 0:
        print(f"训练损失(每批次的损失)在步骤 {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"训练准确率: {train_acc_metric.result()}")
        print(f"到目前为止: {(step + 1) * batch_size} 个样本")

metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables

# 验证循环
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # 每100个批次记录一次日志。
    if step % 100 == 0:
        print(f"验证损失(每批次的损失)在步骤 {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"验证准确率: {val_acc_metric.result()}")
        print(f"到目前为止: {(step + 1) * batch_size} 个样本")
训练损失(每批次的损失)在步骤 0: 96.4990
训练准确率: 0.0625
到目前为止: 32 个样本
训练损失(每批次的损失)在步骤 100: 2.0447
训练准确率: 0.6064356565475464
到目前为止: 3232 个样本
训练损失(每批次的损失)在步骤 200: 2.0184
训练准确率: 0.6934079527854919
到目前为止: 6432 个样本
训练损失(每批次的损失)在步骤 300: 1.9111
训练准确率: 0.7303779125213623
到目前为止: 9632 个样本
训练损失(每批次的损失)在步骤 400: 1.8042
训练准确率: 0.7555330395698547
到目前为止: 12832 个样本
训练损失(每批次的损失)在步骤 500: 1.2200
训练准确率: 0.7659056782722473
到目前为止: 16032 个样本
训练损失(每批次的损失)在步骤 600: 1.3437
训练准确率: 0.7793781161308289
到目前为止: 19232 个样本
训练损失(每批次的损失)在步骤 700: 1.2409
训练准确率: 0.789318859577179
到目前为止: 22432 个样本
训练损失(每批次的损失)在步骤 800: 1.6530
训练准确率: 0.7977527976036072
到目前为止: 25632 个样本
训练损失(每批次的损失)在步骤 900: 0.4173
训练准确率: 0.8060488104820251
到目前为止: 28832 个样本
训练损失(每批次的损失)在步骤 1000: 0.5543
训练准确率: 0.8100025057792664
到目前为止: 32032 个样本
训练损失(每批次的损失)在步骤 1100: 1.2699
训练准确率: 0.8160762786865234
到目前为止: 35232 个样本
训练损失(每批次的损失)在步骤 1200: 1.2621
训练准确率: 0.8213468194007874
到目前为止: 38432 个样本
训练损失(每批次的损失)在步骤 1300: 0.8028
训练准确率: 0.8257350325584412
到目前为止: 41632 个样本
训练损失(每批次的损失)在步骤 1400: 1.0701
训练准确率: 0.8298090696334839
到目前为止: 44832 个样本
训练损失(每批次的损失)在步骤 1500: 0.3910
训练准确率: 0.8336525559425354
到目前为止: 48032 个样本
验证损失(每批次的损失)在步骤 0: 0.2482
验证准确率: 0.835365355014801
到目前为止: 32 个样本
验证损失(每批次的损失)在步骤 100: 1.1641
验证准确率: 0.8388938903808594
到目前为止: 3232 个样本
验证损失(每批次的损失)在步骤 200: 0.1201
验证准确率: 0.8428196907043457
到目前为止: 6432 个样本
验证损失(每批次的损失)在步骤 300: 0.0755
验证准确率: 0.8471122980117798
到目前为止: 9632 个样本

模型跟踪的损失的低级处理

层和模型通过调用 self.add_loss(value) 的层在前向传播过程中递归跟踪任何产生的损失。前向传播结束时,结果的标量损失值列表可以通过属性 model.losses 获取。

如果你想使用这些损失组件,应该将它们相加并在训练步骤中将其添加到主要损失中。

考虑这个层,它创建了一个活动正则化损失:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

让我们构建一个使用它的非常简单的模型:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# 将活动正则化插入为一层
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

下面是我们现在的 compute_loss_and_updates 函数:

  • return_losses=True 传递给 model.stateless_call()
  • 将结果中的 losses 相加并添加到主损失中。
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)  # 如果有损失,则将其相加
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, non_trainable_variables, metric_variables

就这些!