fit()
中的行为作者: fchollet
创建日期: 2023/06/27
最后修改: 2023/06/27
描述: 使用 JAX 重写模型类的训练步骤。
当你进行监督学习时,可以使用 fit()
,一切都很顺利。
当你需要控制每一个细节时,可以完全从零开始编写自己的训练循环。
但是如果你需要一个自定义的训练算法,但又希望利用 fit()
的便利功能,例如回调、内置分布支持或步伐融合,该怎么办?
Keras 的一个核心原则是 复杂性逐步披露。你应该始终能够逐步进入更低级的工作流程。如果高级功能与你的用例不完全匹配,你不应该跌入万丈深渊。你应该能够在保持相应的高级便利性的同时,获得对细节的更多控制。
当你需要自定义 fit()
的行为时,你应该 重写 Model
类的训练步骤函数。这是 fit()
在每个数据批次上调用的函数。然后你将能够像往常一样调用 fit()
,并且它将运行你自己的学习算法。
请注意,这种模式并不妨碍你使用功能性 API 构建模型。无论你是构建 Sequential
模型、功能性 API 模型,还是子类化模型,你都可以这样做。
让我们看看这是如何工作的。
import os
# 本指南只能在 JAX 后端运行。
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import numpy as np
让我们从一个简单的示例开始:
keras.Model
的子类。compute_loss_and_updates()
方法
来计算损失以及模型中不可训练变量的更新值。内部调用 stateless_call()
和
内置的 compute_loss()
。train_step()
方法,以计算当前的
指标值(包括损失)以及可训练变量、优化器变量和指标变量的更新值。请注意,你还可以通过以下方式考虑 sample_weight
参数:
x, y, sample_weight = data
sample_weight
传递给 compute_loss()
stateless_update_state()
中将 sample_weight
与 y
和 y_pred
一起传递给指标class CustomModel(keras.Model):
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.compute_loss(x, y, y_pred)
return loss, (y_pred, non_trainable_variables)
def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data
# 获取梯度函数。
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
# 计算梯度。
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)
# 更新可训练变量和优化器变量。
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 更新指标。
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
# 返回指标日志和更新后的状态变量。
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state
让我们试试这个:
# 构建并编译 CustomModel 的实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# 像往常一样使用 `fit`
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 1.0022 - loss: 1.2464
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 198us/step - mae: 0.5811 - loss: 0.4912
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 231us/step - mae: 0.4386 - loss: 0.2905
<keras.src.callbacks.history.History at 0x14da599c0>
自然地,你可以跳过在 compile()
中传递损失函数,而是手动在 train_step
中处理一切。度量也是如此。
这是一个更低层次的例子,只使用 compile()
来配置优化器:
class CustomModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_tracker = keras.metrics.Mean(name="loss")
self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
self.loss_fn = keras.losses.MeanSquaredError()
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.loss_fn(y, y_pred)
return loss, (y_pred, non_trainable_variables)
def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data
# 获取梯度函数。
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
# 计算梯度。
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)
# 更新可训练变量和优化器变量。
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 更新指标。
loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]
loss_tracker_vars = self.loss_tracker.stateless_update_state(
loss_tracker_vars, loss
)
mae_metric_vars = self.mae_metric.stateless_update_state(
mae_metric_vars, y, y_pred
)
logs = {}
logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
loss_tracker_vars
)
logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)
new_metrics_vars = loss_tracker_vars + mae_metric_vars
# 返回指标日志和更新后的状态变量。
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state
@property
def metrics(self):
# 我们在这里列出我们的 `Metric` 对象,以便在每个纪元开始时
# 或在 `evaluate()` 开始时可以自动调用 `reset_states()`。
return [self.loss_tracker, self.mae_metric]
# 构建 CustomModel 的实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# 我们在这里不传递损失或指标。
model.compile(optimizer="adam")
# 就像往常一样使用 `fit` -- 你可以使用回调等。
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.6085 - mae: 0.6580
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 215us/step - loss: 0.2630 - mae: 0.4141
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 202us/step - loss: 0.2271 - mae: 0.3835
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 192us/step - loss: 0.2093 - mae: 0.3714
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 194us/step - loss: 0.2188 - mae: 0.3818
<keras.src.callbacks.history.History at 0x14de01420>
如果你想对此调用 model.evaluate()
进行相同的操作?那么你将以完全相同的方式重写 test_step
。这看起来是这样的:
class CustomModel(keras.Model):
def test_step(self, state, data):
# 解压数据。
x, y = data
(
trainable_variables,
non_trainable_variables,
metrics_variables,
) = state
# 计算预测值和损失。
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=False,
)
loss = self.compute_loss(x, y, y_pred)
# 更新指标。
new_metrics_vars = []
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
# 返回指标日志和更新后的状态变量。
state = (
trainable_variables,
non_trainable_variables,
new_metrics_vars,
)
return logs, state
# 构建 CustomModel 的实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# 使用自定义 test_step 进行评估
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 973us/step - mae: 0.7887 - loss: 0.8385
[0.8385222554206848, 0.7956181168556213]
就这些!