JAX中的广义卷积#

在Colab中打开 在Kaggle中打开

JAX提供了多种接口来计算数据之间的卷积,包括:

对于基本的卷积操作,jax.numpyjax.scipy的操作通常已经足够。如果您想进行更一般的批量多维卷积,jax.lax函数是您应该开始的地方。

基本的一维卷积#

基本的一维卷积是通过 jax.numpy.convolve() 实现的,该接口为 numpy.convolve() 提供了 JAX 接口。下面是一个通过卷积实现的一维平滑的简单示例:

import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp
import numpy as np

key = random.key(1701)

x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))

window = jnp.ones(10) / 10
y_smooth = jnp.convolve(y, window, mode='same')

plt.plot(x, y, 'lightgray')
plt.plot(x, y_smooth, 'black');
../_images/c35efc499faf7462bde6068d6ebb586f6fbc50e75dcf8823c84ae274801c0c0f.png

mode 参数控制边界条件的处理方式;在这里我们使用 mode='same' 来确保输出与输入具有相同的大小。

有关更多信息,请参见 jax.numpy.convolve() 文档,或与原始 numpy.convolve() 函数相关的文档。

基础N维卷积#

对于N维卷积,jax.scipy.signal.convolve() 提供了与 jax.numpy.convolve() 类似的接口,并扩展到N维。

例如,以下是基于与高斯滤波器的卷积进行图像去噪的一个简单方法:

from scipy import misc
import jax.scipy as jsp

fig, ax = plt.subplots(1, 3, figsize=(12, 5))

# 加载一张示例图像;使用mean()函数将其从RGB转换为灰度图像。
image = jnp.array(misc.face().mean(-1))
ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')

# 通过添加随机高斯噪声生成一个噪声版本
key = random.key(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')

# 使用二维高斯平滑核对噪声图像进行平滑处理。
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
smooth_image = jsp.signal.convolve(noisy_image, window, mode='same')
ax[2].imshow(smooth_image, cmap='binary_r')
ax[2].set_title('smoothed');
/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22811/1094918267.py:7: DeprecationWarning: scipy.misc.face has been deprecated in SciPy v1.10.0; and will be completely removed in SciPy v1.12.0. Dataset methods have moved into the scipy.datasets module. Use scipy.datasets.face instead.
  image = jnp.array(misc.face().mean(-1))
../_images/680f94dd73ea113b11a54238d08d66abeffb9229029599a18003ac84f65dc6b8.png

与一维情况一样,我们使用 mode='same' 来指定我们希望如何处理边缘。有关N维卷积中可用选项的更多信息,请参见 jax.scipy.signal.convolve() 文档。

一般卷积#

对于在构建深度神经网络时经常有用的更一般类型的批处理卷积,JAX和XLA提供了非常通用的N维__conv_general_dilated__函数,但如何使用它并不明显。我们将给出一些常见用例的示例。

强烈推荐阅读关于卷积算子的家族调查,卷积算术指南!

接下来,让我们定义一个简单的对角边缘卷积核:

# 2D 核 - HWIO 布局
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:
../_images/dd7641e945b192f1517147d9a1c3935179f6acdac606ea84ea0d86aaf2f5498e.png

我们将制作一个简单的合成图像:

# NHWC布局
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
for k in range(3):
  x = 30 + 60*k
  y = 20 + 60*k
  img = img.at[0, x:x+10, y:y+10, k].set(1.0)

print("Original Image:")
plt.imshow(img[0]);
Original Image:
../_images/29e909ac5254e308bad710d4dc5e360f2821f7ca318502c500231885aa837854.png

lax.conv 和 lax.conv_with_general_padding#

这些是用于卷积的简单便利函数

️⚠️ 便利的 lax.convlax.conv_with_general_padding 辅助函数假定 NCHW 图像和 OIHW 核心。

from jax import lax
out = lax.conv(jnp.transpose(img,[0,3,1,2]),    # lhs = NCHW 图像张量
               jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW 卷积核张量
               (1, 1),  # 窗口步幅
               'SAME') # 填充模式
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape:  (1, 3, 200, 198)
First output channel:
../_images/010556bcf221e024f8eaa6a6076da1ffb2316a8d56eebe2ee8c3c16e9a5b16c5.png
out = lax.conv_with_general_padding(
  jnp.transpose(img,[0,3,1,2]),    # lhs = NCHW 图像张量
  jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW 卷积核张量
  (1, 1),  # 窗口步幅
  ((2,2),(2,2)), # 通用填充 2x2
  (1,1),  # lhs/图像膨胀
  (1,1))  # 右半边/核扩张
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape:  (1, 3, 202, 200)
First output channel:
../_images/34032081220009c2072739c9c68904d915cf7ee4f183c047d3c63f28d4f76fa4.png

维度数字定义 conv_general_dilated 的维度布局#

重要的参数是轴布局参数的三元组: (输入布局,内核布局,输出布局)

  • N - 批次维度

  • H - 空间高度

  • W - 空间宽度

  • C - 通道维度

  • I - 内核_输入_通道维度

  • O - 内核_输出_通道维度

⚠️ 为了展示维度数字的灵活性,我们选择了 NHWC 图像和 HWIO 内核约定用于下面的 lax.conv_general_dilated

dn = lax.conv_dimension_numbers(img.shape,     # 只有ndim重要,形状无关紧要。
                                kernel.shape,  # 只有ndim重要,形状无关紧要。
                                ('NHWC', 'HWIO', 'NHWC'))  # 关键部分
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))

SAME 填充,无步幅,无扩张#

out = lax.conv_general_dilated(img,    # lhs = 图像张量
                               kernel, # rhs = 卷积核张量
                               (1,1),  # 窗口步幅
                               'SAME', # 填充模式
                               (1,1),  # lhs/图像膨胀
                               (1,1),  # 右侧/内核扩张
                               dn)     # 维度编号 = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 200, 198, 3)
First output channel:
../_images/010556bcf221e024f8eaa6a6076da1ffb2316a8d56eebe2ee8c3c16e9a5b16c5.png

有效填充,无步幅,无膨胀#

out = lax.conv_general_dilated(img,     # lhs = 图像张量
                               kernel,  # rhs = 卷积核张量
                               (1,1),   # 窗口步幅
                               'VALID', # 填充模式
                               (1,1),   # lhs/图像膨胀
                               (1,1),   # rhs/内核扩张
                               dn)      # dimension_numbers = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 198, 196, 3) DIFFERENT from above!
First output channel:
../_images/9824c23ac96e76e2b2b3dc532b69a55d53d75d5c83163d873501d07ded0028f8.png

SAME 填充,2,2 步幅,无扩张#

out = lax.conv_general_dilated(img,    # lhs = 图像张量
                               kernel, # rhs = 卷积核张量
                               (2,2),  # 窗口步幅
                               'SAME', # 填充模式
                               (1,1),  # lhs/图像膨胀
                               (1,1),  # 右半轴/核扩张
                               dn)     # dimension_numbers = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 100, 99, 3)  <-- half the size of above
First output channel:
../_images/0e9bea07ce50502c53a0baf1c651a28666ddc79871c229e2eaa2a1d8089f7e9b.png

VALID填充,无步幅,右侧卷积核扩张 ~ Atrous卷积(过度以便说明)#

out = lax.conv_general_dilated(img,     # lhs = 图像张量
                               kernel,  # rhs = 卷积核张量
                               (1,1),   # 窗口步幅
                               'VALID', # 填充模式
                               (1,1),   # lhs/图像膨胀
                               (12,12), # 右半轴/核扩张
                               dn)      # dimension_numbers = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 176, 174, 3)
First output channel:
../_images/e2c2eb80e9bcbda0a06bd0a3ec4a243206c62beb02d880cf1fb2fd7da9c5ff6a.png

有效填充,无步幅,左侧是输入扩张 ~ 转置卷积#

out = lax.conv_general_dilated(img,               # lhs = 图像张量
                               kernel,            # rhs = 卷积核张量
                               (1,1),             # 窗口步幅
                               ((0, 0), (0, 0)),  # 填充模式
                               (2,2),             # lhs/图像膨胀
                               (1,1),             # rhs/内核扩张
                               dn)                # dimension_numbers = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 397, 393, 3) <-- larger than original!
First output channel:
../_images/54466faf066cd4df58855f3efe8f5fd4be520fabe28fbaf1cd7ae59dd0109dbb.png

我们可以使用最后一种方法,例如,实现_转置卷积_:

# 以下等同于tensorflow:
# N, H, W, C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# 转置卷积 = 180度核旋转加上左侧扩张
# 旋转内核180度:
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# 需要自定义输出填充:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,     # lhs = 图像张量
                               kernel_rot,  # rhs = 卷积核张量
                               (1,1),   # 窗口步幅
                               padding, # 填充模式
                               (2,2),   # lhs/图像膨胀
                               (1,1),   # rhs/内核扩张
                               dn)      # 维度编号 = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 400, 396, 3) <-- transposed_conv
First output channel:
../_images/6d14f7c6860fb409102935f1cb9b0ed5ad6f93d9433a997bdf52904b8cd165f5.png

一维卷积#

您并不局限于2D卷积,下面是一个简单的一维示例:

# 一维核 - WIO布局
kernel = jnp.array([[[1, 0, -1], [-1,  0,  1]],
                    [[1, 1,  1], [-1, -1, -1]]],
                    dtype=jnp.float32).transpose([2,1,0])
# 一维数据 - NWC布局
data = np.zeros((1, 200, 2), dtype=jnp.float32)
for i in range(2):
  for k in range(2):
      x = 35*i + 30 + 60*k
      data[0, x:x+30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,   # lhs = 图像张量
                               kernel, # rhs = 卷积核张量
                               (1,),   # 窗口步幅
                               'SAME', # 填充模式
                               (1,),   # lhs/图像膨胀
                               (1,),   # 右半轴/核扩张
                               dn)     # 维度编号 = 左操作数, 右操作数, 输出维度排列
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape:  (1, 200, 2)
../_images/78fc31c8ffcf33926558abe0dbff31244444d74e1b39aac0a3b0f9ab7cbe4eec.png ../_images/04b2dc814cbc18ebc5a6eb9c8134dbb453160f66109ecbbf5d55ce0cd67cf26b.png

3D 卷积#

import matplotlib as mpl

# 随机3D核 - HWDIO布局
kernel = jnp.array([
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]],
  [[0, -1, 0], [-1, 0, -1], [0,  -1,  0]],
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]]],
  dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]

# 三维数据 - NHWDC布局
data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]

print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = 图像张量
                               kernel,  # rhs = 卷积核张量
                               (1,1,1), # 窗口步幅
                               'SAME',  # 填充模式
                               (1,1,1), # lhs/图像膨胀
                               (1,1,1), # rhs/内核扩张
                               dn)      # 维度编号
print("out shape: ", out.shape)

# 制作一些简单的三维密度图:
def make_alpha(cmap):
  my_cmap = cmap(jnp.arange(cmap.N))
  my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
  return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape:  (1, 30, 30, 30, 1)
../_images/b4a1b95cb0fcdaea84a3e991deaadb1619aee0baf29815727484b2e42a095f1e.png ../_images/645be19fe70e82f41ad95b80ec41f1f77a92ad77517bd41af3a052313bb991c8.png