TorchModuleWrapper
classkeras.layers.TorchModuleWrapper(module, name=None, **kwargs)
Torch 模块包装层.
TorchModuleWrapper
是一个包装类,可以将任何 torch.nn.Module
转换为 Keras 层,特别是通过使其参数可被 Keras 跟踪.
TorchModuleWrapper
仅兼容 PyTorch 后端,不能与 TensorFlow 或 JAX 后端一起使用.
参数:
module: torch.nn.Module
实例.如果它是 LazyModule
实例,则必须在将其传递给 TorchModuleWrapper
之前初始化其参数(例如,通过调用一次).
name: 层的名称(字符串).
示例:
以下是如何将 TorchModuleWrapper
与普通 PyTorch 模块一起使用的示例.
import torch.nn as nn
import torch.nn.functional as F
import keras
from keras.src.layers import TorchModuleWrapper
class Classifier(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 如果包含参数,使用 `TorchModuleWrapper` 包装 `torch.nn.Module`
self.conv1 = TorchModuleWrapper(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
)
self.pool = nn.MaxPool2d(kernel_size=(2, 2))
self.flatten = nn.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.fc = TorchModuleWrapper(nn.Linear(1600, 10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("# Output shape", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)