作者: fchollet
创建日期: 2023/06/25
最后修改: 2023/06/25
描述: 在 JAX 中编写低级训练和评估循环。
import os
# 本指南只能在 jax 后端运行。
os.environ["KERAS_BACKEND"] = "jax"
import jax
# 我们导入 TF 以便可以使用 tf.data。
import tensorflow as tf
import keras
import numpy as np
Keras 提供了默认的训练和评估循环,fit()
和 evaluate()
。
它们的用法在指南
使用内置方法进行训练和评估 中有所介绍。
如果你想自定义模型的学习算法,同时仍然利用 fit()
的便利性
(例如,使用 fit()
训练 GAN),你可以子类化 Model
类并
实现你自己的 train_step()
方法,
该方法在 fit()
期间被反复调用。
现在,如果你想要非常低级的控制训练和评估,你应该从头编写自己的训练和评估循环。这就是本指南的内容。
要编写自定义训练循环,我们需要以下成分:
keras.optimizers
中的优化器,或者
来自 optax
包的优化器。tf.data
加载数据,
所以我们也将使用它。让我们把它们列出来。
首先,让我们获取模型和 MNIST 数据集:
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = get_model()
# 准备训练数据集。
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# 保留 10,000 个样本用于验证。
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# 准备训练数据集。
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# 准备验证数据集。
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
接下来,这里是损失函数和优化器。 在这种情况下,我们将使用 Keras 优化器。
# 实例化一个损失函数。
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# 实例化一个优化器。
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
让我们使用自定义训练循环和迷你批量梯度来训练我们的模型。
在 JAX 中,梯度是通过元编程计算的:你调用 jax.grad
(或 jax.value_and_grad
)对一个函数进行调用,以创建该函数的梯度计算函数。
所以我们首先需要的是一个返回损失值的函数。这就是我们将用来生成梯度函数的函数。类似于这样:
def compute_loss(x, y):
...
return loss
一旦你有了这样的函数,你可以通过元编程计算梯度,如下所示:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
通常,你不仅想获取梯度值,还想获取损失值。你可以使用 jax.value_and_grad
而不是 jax.grad
来实现这一点:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
在 JAX 中,一切必须是纯无状态函数——因此我们的损失计算函数也必须是无状态的。这意味着所有 Keras 变量(例如权重张量)必须作为函数输入传递,并且在正向传递期间更新的任何变量必须作为函数输出返回。该函数不能有副作用。
在正向传递期间,Keras 模型的非训练变量可能会被更新。这些变量可以是,例如,RNG 种子状态变量或 BatchNormalization 统计数据。我们需要返回这些变量。所以我们需要类似这样的东西:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
...
return loss, non_trainable_variables
一旦你有了这样的函数,你可以通过在 value_and_grad
中指定 has_aux
来获取梯度函数:它告诉 JAX 损失计算函数返回的不仅仅是损失。请注意,损失应该始终是第一个输出。
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
现在我们已经建立了基础,让我们实现这个 compute_loss_and_updates
函数。Keras 模型具有一个 stateless_call
方法,这在这里会派上用场。它的工作方式与 model.__call__
一样,但需要你显式传递模型中所有变量的值,并且它不仅返回 __call__
的输出,还返回(可能已更新的)非可训练变量。
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
return loss, non_trainable_variables
让我们获取梯度函数:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
接下来,让我们实现端到端训练步骤的函数,该函数将同时运行前向传递,计算损失,计算梯度,还使用优化器更新可训练变量。这个函数也需要是无状态的,因此它将作为输入一个 state
元组,其中包含我们将要使用的每个状态元素:
trainable_variables
和 non_trainable_variables
:模型的变量。optimizer_variables
:优化器的状态变量,例如动量累加器。要更新可训练变量,我们使用优化器的无状态方法 stateless_apply
。这相当于 optimizer.apply()
,但它总是需要传递 trainable_variables
和 optimizer_variables
。它返回更新后的可训练变量和更新后的优化器变量。
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
grads, trainable_variables, optimizer_variables
)
# 返回更新后的状态
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
jax.jit
加快速度默认情况下,JAX 操作是立即执行的,就像在 TensorFlow eager 模式和 PyTorch eager 模式中一样。并且就像 TensorFlow eager 模式和 PyTorch eager 模式一样,它非常慢——即发模式更适合作为调试环境,而不是实际工作的方式。让我们通过编译我们的 train_step
来加快速度。
当你有一个无状态的 JAX 函数时,你可以通过 @jax.jit
装饰器将其编译为 XLA。它将在第一次执行时被追踪,在随后的执行中,你将执行已追踪的图(这与 @tf.function(jit_compile=True)
非常相似)。让我们尝试一下:
@jax.jit
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 返回更新后的状态
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
我们现在准备好训练我们的模型。训练循环本身非常简单:我们只需重复调用 loss, state = train_step(state, data)
。
注意:
tf.data.Dataset
生成的 TF 张量传递给我们的 JAX 函数之前,将其转换为 NumPy。# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# Log every 100 batches.
if step % 100 == 0:
print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
print(f"Seen so far: {(step + 1) * batch_size} samples")
训练损失(每批次)在步骤 0: 156.4785
到目前为止: 32 个样本
训练损失(每批次)在步骤 100: 2.5526
到目前为止: 3232 个样本
训练损失(每批次)在步骤 200: 1.8922
到目前为止: 6432 个样本
训练损失(每批次)在步骤 300: 1.2381
到目前为止: 9632 个样本
训练损失(每批次)在步骤 400: 0.4812
到目前为止: 12832 个样本
训练损失(每批次)在步骤 500: 2.3339
到目前为止: 16032 个样本
训练损失(每批次)在步骤 600: 0.5615
到目前为止: 19232 个样本
训练损失(每批次)在步骤 700: 0.6471
到目前为止: 22432 个样本
训练损失(每批次)在步骤 800: 1.6272
到目前为止: 25632 个样本
训练损失(每批次)在步骤 900: 0.9416
到目前为止: 28832 个样本
训练损失(每批次)在步骤 1000: 0.8152
到目前为止: 32032 个样本
训练损失(每批次)在步骤 1100: 0.8838
到目前为止: 35232 个样本
训练损失(每批次)在步骤 1200: 0.1278
到目前为止: 38432 个样本
训练损失(每批次)在步骤 1300: 1.9234
到目前为止: 41632 个样本
训练损失(每批次)在步骤 1400: 0.3413
到目前为止: 44832 个样本
训练损失(每批次)在步骤 1500: 0.2429
到目前为止: 48032 个样本
这里一个关键点是循环完全是无状态的 - 附加到模型的变量 (model.weights
) 在循环中从未被更新。它们的新值仅存储在 state
元组中。这意味着在某个时刻,在保存模型之前,您应该将新的变量值附加回模型。
只需对每个要更新的模型变量调用 variable.assign(new_value)
:
可训练变量、非可训练变量、优化器变量 = state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
variable.assign(value)
让我们在这个基本训练循环中添加指标监控。
您可以在这样从头编写的训练循环中轻松重用内置 Keras 指标(或您编写的自定义指标)。流程如下:
metric_variables
包含在 train_step
的参数中
和 compute_loss_and_updates
的参数中。compute_loss_and_updates
函数中调用 metric.stateless_update_state()
。
它等同于 update_state()
- 只是无状态的。train_step
之外
(在急切范围内),将新的指标变量值附加到指标对象
并调用 metric.result()
。metric.reset_state()
(通常在一个 epoch 结束时)让我们利用这些知识在训练结束时计算训练和验证数据上的 CategoricalAccuracy
:
# 获取一个新模型
model = get_model()
# 实例化一个优化器以训练模型。
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# 实例化损失函数。
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# 准备指标。
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
x, y = data
(loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
trainable_variables, non_trainable_variables, metric_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 返回更新后的状态
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
我们还会准备一个评估步骤函数:
@jax.jit
def eval_step(state, data):
trainable_variables, non_trainable_variables, metric_variables = state
x, y = data
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = val_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (
trainable_variables,
non_trainable_variables,
metric_variables,
)
这里是我们的循环:
# 构建优化器变量。
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
# 训练循环
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# 每100个批次记录一次日志。
if step % 100 == 0:
print(f"训练损失(每批次的损失)在步骤 {step}: {float(loss):.4f}")
_, _, _, metric_variables = state
for variable, value in zip(train_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"训练准确率: {train_acc_metric.result()}")
print(f"到目前为止: {(step + 1) * batch_size} 个样本")
metric_variables = val_acc_metric.variables
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# 验证循环
for step, data in enumerate(val_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = eval_step(state, data)
# 每100个批次记录一次日志。
if step % 100 == 0:
print(f"验证损失(每批次的损失)在步骤 {step}: {float(loss):.4f}")
_, _, metric_variables = state
for variable, value in zip(val_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"验证准确率: {val_acc_metric.result()}")
print(f"到目前为止: {(step + 1) * batch_size} 个样本")
训练损失(每批次的损失)在步骤 0: 96.4990
训练准确率: 0.0625
到目前为止: 32 个样本
训练损失(每批次的损失)在步骤 100: 2.0447
训练准确率: 0.6064356565475464
到目前为止: 3232 个样本
训练损失(每批次的损失)在步骤 200: 2.0184
训练准确率: 0.6934079527854919
到目前为止: 6432 个样本
训练损失(每批次的损失)在步骤 300: 1.9111
训练准确率: 0.7303779125213623
到目前为止: 9632 个样本
训练损失(每批次的损失)在步骤 400: 1.8042
训练准确率: 0.7555330395698547
到目前为止: 12832 个样本
训练损失(每批次的损失)在步骤 500: 1.2200
训练准确率: 0.7659056782722473
到目前为止: 16032 个样本
训练损失(每批次的损失)在步骤 600: 1.3437
训练准确率: 0.7793781161308289
到目前为止: 19232 个样本
训练损失(每批次的损失)在步骤 700: 1.2409
训练准确率: 0.789318859577179
到目前为止: 22432 个样本
训练损失(每批次的损失)在步骤 800: 1.6530
训练准确率: 0.7977527976036072
到目前为止: 25632 个样本
训练损失(每批次的损失)在步骤 900: 0.4173
训练准确率: 0.8060488104820251
到目前为止: 28832 个样本
训练损失(每批次的损失)在步骤 1000: 0.5543
训练准确率: 0.8100025057792664
到目前为止: 32032 个样本
训练损失(每批次的损失)在步骤 1100: 1.2699
训练准确率: 0.8160762786865234
到目前为止: 35232 个样本
训练损失(每批次的损失)在步骤 1200: 1.2621
训练准确率: 0.8213468194007874
到目前为止: 38432 个样本
训练损失(每批次的损失)在步骤 1300: 0.8028
训练准确率: 0.8257350325584412
到目前为止: 41632 个样本
训练损失(每批次的损失)在步骤 1400: 1.0701
训练准确率: 0.8298090696334839
到目前为止: 44832 个样本
训练损失(每批次的损失)在步骤 1500: 0.3910
训练准确率: 0.8336525559425354
到目前为止: 48032 个样本
验证损失(每批次的损失)在步骤 0: 0.2482
验证准确率: 0.835365355014801
到目前为止: 32 个样本
验证损失(每批次的损失)在步骤 100: 1.1641
验证准确率: 0.8388938903808594
到目前为止: 3232 个样本
验证损失(每批次的损失)在步骤 200: 0.1201
验证准确率: 0.8428196907043457
到目前为止: 6432 个样本
验证损失(每批次的损失)在步骤 300: 0.0755
验证准确率: 0.8471122980117798
到目前为止: 9632 个样本
层和模型通过调用 self.add_loss(value)
的层在前向传播过程中递归跟踪任何产生的损失。前向传播结束时,结果的标量损失值列表可以通过属性 model.losses
获取。
如果你想使用这些损失组件,应该将它们相加并在训练步骤中将其添加到主要损失中。
考虑这个层,它创建了一个活动正则化损失:
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * jax.numpy.sum(inputs))
return inputs
让我们构建一个使用它的非常简单的模型:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# 将活动正则化插入为一层
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
下面是我们现在的 compute_loss_and_updates
函数:
return_losses=True
传递给 model.stateless_call()
。losses
相加并添加到主损失中。def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables, losses = model.stateless_call(
trainable_variables, non_trainable_variables, x, return_losses=True
)
loss = loss_fn(y, y_pred)
if losses:
loss += jax.numpy.sum(losses) # 如果有损失,则将其相加
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, non_trainable_variables, metric_variables
就这些!