Autobatching for Bayesian inference#

This notebook demonstrates a simple Bayesian inference example in which autobatching makes user code easier to write, easier to read, and less likely to include bugs.

Inspired by a notebook by @davmre.
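
As a quick refresher (a minimal sketch, not part of the original example): `jax.vmap` transforms a function written for a single example into one that maps over a leading batch axis.

def squared_norm(v):
    # Written with a single vector in mind.
    return jnp.dot(v, v)

# `jax.vmap` maps the function over a leading batch axis of its input.
batched_squared_norm = jax.vmap(squared_norm)
print(batched_squared_norm(jnp.ones((3, 4))))  # [4. 4. 4.]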

import matplotlib.pyplot as plt

import jax

import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import numpy as np
import scipy as sp

Generate a fake binary classification dataset#

np.random.seed(10009)

num_features = 10
num_points = 100

true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
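
In symbols, the generative process just simulated is (with $\sigma$ the logistic sigmoid, i.e. `sp.special.expit`, and $\beta$ and each $x_i$ drawn from standard normals):

$$y_i \sim \mathrm{Bernoulli}\big(\sigma(x_i^\top \beta)\big), \qquad \sigma(t) = \frac{1}{1 + e^{-t}}.$$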

Write the log-joint function for the model#

We'll write a non-batched version, a manually batched version, and an autobatched version.

Non-batched#

def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result
log_joint(np.random.randn(num_features))
Array(-213.2356, dtype=float32)
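
Since $-\log(1 + e^{-t}) = \log \sigma(t)$, the function above computes the log-joint density of a standard-normal prior on $\beta$ and a Bernoulli likelihood:

$$\log p(\beta, y) = \sum_j \log \mathcal{N}(\beta_j \mid 0, 1) + \sum_i \log \sigma\big((2 y_i - 1)\, x_i^\top \beta\big).$$
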
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)

  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)]
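
To see where this comes from, here is a shape trace (a small sketch using the arrays defined above): inside `log_joint`, the dot product picks up an extra trailing axis that `y` cannot broadcast against.

print(jnp.dot(all_x, batched_test_beta).shape)  # (100, 10): an extra trailing axis
print((2 * y - 1).shape)                        # (100,)
# Broadcasting aligns trailing axes, so (100,) vs (100, 10) compares 100 with 10 and fails.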

Manually batched#

def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to
    # set axis or setting it incorrectly yields an error; at worst, it silently
    # changes the semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                           axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                           axis=-1)
    return result
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)
Array([-147.84033, -207.02205, -109.26076, -243.80833, -163.0291 ,
       -143.8485 , -160.28773, -113.7717 , -126.60544, -190.81989],      dtype=float32)
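
As an aside (an alternative sketch, not part of the original): `jnp.einsum` expresses the same batched dot product with the contracted axis named explicitly, avoiding the transpose dance.

# 'np,bp->bn': n indexes data points, p features, b batch members.
logits = jnp.einsum('np,bp->bn', all_x, batched_test_beta)
print(jnp.allclose(logits, jnp.dot(all_x, batched_test_beta.T).T))  # True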

Autobatched with vmap#

It just works.

vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84033, -207.02205, -109.26076, -243.80833, -163.0291 ,
       -143.8485 , -160.28773, -113.7717 , -126.60544, -190.81989],      dtype=float32)
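
By default `jax.vmap` maps over axis 0 of each input; the `in_axes` argument selects a different axis. A quick sketch on the same data:

# Map over axis 1 of a (features, batch) array instead; results are identical.
vmap_over_axis1 = jax.vmap(log_joint, in_axes=1)
print(jnp.allclose(vmap_over_axis1(batched_test_beta.T),
                   vmap_batched_log_joint(batched_test_beta)))  # True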

Self-contained variational inference example#

Some code is copied from above.

Set up the (batched) log-joint function#

@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))
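
`jax.jit` and `jax.vmap` compose in either order and compute the same values; a small sanity-check sketch:

# vmap-of-jit agrees with jit-of-vmap (log_joint is itself jitted here).
alt_batched = jax.vmap(log_joint)
print(jnp.allclose(alt_batched(batched_test_beta), batched_log_joint(batched_test_beta)))  # True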

Define the ELBO and its gradient#

def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
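
For reference, with $q(\beta) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$, where $\mu$ is `beta_loc` and $\sigma = \exp($`beta_log_scale`$)$, the code above estimates the reparameterized ELBO: the Monte Carlo average runs over the batch of $\epsilon$ samples, and the trailing sum equals the Gaussian entropy of $q$ up to an additive constant, which leaves the gradients unchanged:

$$\mathcal{L}(\mu, \log \sigma) \approx \frac{1}{B} \sum_{b=1}^{B} \log p\big(\mu + \sigma \odot \epsilon_b,\; y\big) + \sum_j \Big(\log \sigma_j - \tfrac{1}{2} \log 2\pi\Big), \qquad \epsilon_b \sim \mathcal{N}(0, I).$$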

Optimize the ELBO using SGD#

def normal_sample(key, shape):
    """准状态随机数生成器的便捷函数。"""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

# `shape` is static because, under `jit`, output shapes must be known at trace time.
normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.key(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-180.8538818359375
10	-113.06045532226562
20	-102.73727416992188
30	-99.787353515625
40	-98.90898132324219
50	-98.29745483398438
60	-98.18633270263672
70	-97.57972717285156
80	-97.28599548339844
90	-97.46996307373047
100	-97.47715759277344
110	-97.5806655883789
120	-97.4943618774414
130	-97.50271606445312
140	-96.86396026611328
150	-97.44197845458984
160	-97.06941223144531
170	-96.84027862548828
180	-97.21337127685547
190	-97.56502532958984
200	-97.26397705078125
210	-97.11979675292969
220	-97.39595794677734
230	-97.16831970214844
240	-97.118408203125
250	-97.24345397949219
260	-97.29788970947266
270	-96.69285583496094
280	-96.96438598632812
290	-97.30055236816406
300	-96.63592529296875
310	-97.03518676757812
320	-97.52909851074219
330	-97.28813171386719
340	-97.07322692871094
350	-97.15620422363281
360	-97.25881958007812
370	-97.19515228271484
380	-97.13092041015625
390	-97.11727905273438
400	-96.938720703125
410	-97.26676940917969
420	-97.35322570800781
430	-97.21007537841797
440	-97.28436279296875
450	-97.16307830810547
460	-97.2612533569336
470	-97.21343231201172
480	-97.23997497558594
490	-97.14913177490234
500	-97.23527526855469
510	-96.93419647216797
520	-97.21209716796875
530	-96.82575988769531
540	-97.01285552978516
550	-96.94176483154297
560	-97.16520690917969
570	-97.2916488647461
580	-97.42941284179688
590	-97.24371337890625
600	-97.15222930908203
610	-97.49844360351562
620	-96.99069213867188
630	-96.88956451416016
640	-96.89968872070312
650	-97.137939453125
660	-97.43706512451172
670	-96.99235534667969
680	-97.15623474121094
690	-97.1869125366211
700	-97.11161041259766
710	-97.78104400634766
720	-97.23226165771484
730	-97.16206359863281
740	-96.99581909179688
750	-96.66722106933594
760	-97.16795349121094
770	-97.51435089111328
780	-97.28900146484375
790	-96.91226196289062
800	-97.17098999023438
810	-97.29048156738281
820	-97.16242218017578
830	-97.1910629272461
840	-97.56382751464844
850	-97.00194549560547
860	-96.86555480957031
870	-96.76338195800781
880	-96.83661651611328
890	-97.12178802490234
900	-97.09554290771484
910	-97.06825256347656
920	-97.11947631835938
930	-96.87930297851562
940	-97.45625305175781
950	-96.69279479980469
960	-97.29376220703125
970	-97.3353042602539
980	-97.34962463378906
990	-97.09675598144531

Display the results#

Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.

plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')
<matplotlib.legend.Legend at 0x120e86e50>
[Figure: approximated posterior means and 2σ error bars plotted against the true beta values, with the identity line y = x for reference.]