Autobatching for Bayesian Inference#
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp
Generate a fake binary classification dataset#
# Build a synthetic logistic-regression dataset. The order of the draws below
# (coefficients, then features, then uniforms) determines what the seeded
# generator produces, so it must not be rearranged.
np.random.seed(10009)

num_features, num_points = 10, 100

# Ground-truth coefficients and the feature matrix, both stored as float32.
true_beta = np.random.randn(num_features).astype(np.float32)
all_x = np.random.randn(num_points, num_features).astype(np.float32)

# Label i is 1 when a uniform draw falls below sigmoid(all_x[i] . true_beta).
y = (
    np.random.rand(num_points) < sp.special.expit(all_x @ true_beta)
).astype(np.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model#
We’ll write a non-batched version, a manually batched version, and an autobatched version.
Non-batched#
def log_joint(beta):
    """Return the log joint density of `beta` and the observed labels.

    This version is deliberately not batch-aware.

    Args:
        beta: coefficient vector with one entry per feature column of the
            module-level `all_x`.

    Returns:
        A scalar: the Normal(0, 1) log-prior summed over `beta` plus the
        logistic log-likelihood summed over the module-level data
        (`all_x`, `y`).
    """
    # Note that no `axis` parameter is provided to `jnp.sum`.
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    # (2*y-1) maps labels {0, 1} to signs {-1, +1}; each likelihood term is
    # log(sigmoid(signed_margin)).
    signed_margin = (2*y-1) * jnp.dot(all_x, beta)
    # -logaddexp(0, -m) equals -log(1 + exp(-m)) but stays finite when
    # exp(-m) would overflow the float range (the naive form returns -inf).
    log_likelihood = jnp.sum(-jnp.logaddexp(0., -signed_margin))
    return log_prior + log_likelihood
# Sanity check: evaluate the non-batched log joint at one random coefficient vector.
log_joint(np.random.randn(num_features))
Array(-213.23558, dtype=float32)
# `log_joint()` was not written to handle a leading batch dimension, so feeding
# it a (batch_size, num_features) array is expected to fail.
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
try:
    # NOTE: a second, fresh sample is drawn for the call; `batched_test_beta`
    # itself is not what gets passed in.
    log_joint(np.random.randn(batch_size, num_features))
except ValueError as err:
    print(f"Caught expected exception {err}")
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
Manually batched#
def batched_log_joint(beta):
    """Return the log joint density for every row of a batch of coefficients.

    Args:
        beta: array whose last axis indexes features and whose leading axis
            indexes batch members (it is transposed with `.T` below, so it is
            treated as a matrix).

    Returns:
        One log-joint value per batch member; both reductions run over the
        last axis only.
    """
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                        axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    signed_margin = (2*y-1) * jnp.dot(all_x, beta.T).T
    # -logaddexp(0, -m) equals -log(1 + exp(-m)) but stays finite when
    # exp(-m) would overflow the float range (the naive form returns -inf).
    log_likelihood = jnp.sum(-jnp.logaddexp(0., -signed_margin), axis=-1)
    return log_prior + log_likelihood
# Draw a batch of 10 random coefficient vectors and evaluate the manually batched version on it.
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,
-163.02911377, -143.84848022, -160.28771973, -113.77169037,
-126.60544586, -190.81988525], dtype=float32)
Autobatched with vmap#
It just works.
# `jax.vmap` wraps the non-batched `log_joint` so that it is applied across the
# batch of coefficient vectors; the function body itself is left untouched.
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,
-163.02911377, -143.84848022, -160.28771973, -113.77169037,
-126.60544586, -190.81988525], dtype=float32)
Self-contained variational inference example#
A little code is copied from above; note that the prior scale on `beta` is 10 here rather than 1.
Set up the (batched) log-joint function#
@jax.jit
def log_joint(beta):
    """Return the log joint density of `beta` and the observed labels.

    Args:
        beta: coefficient vector with one entry per feature column of the
            module-level `all_x`.

    Returns:
        A scalar: the Normal(0, 10) log-prior summed over `beta` plus the
        logistic log-likelihood summed over the module-level data
        (`all_x`, `y`).
    """
    # Note that no `axis` parameter is provided to `jnp.sum`.
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    # (2*y-1) maps labels {0, 1} to signs {-1, +1}; each likelihood term is
    # log(sigmoid(signed_margin)).
    signed_margin = (2*y-1) * jnp.dot(all_x, beta)
    # -logaddexp(0, -m) equals -log(1 + exp(-m)) but stays finite when
    # exp(-m) would overflow; the naive form returns -inf there and produces
    # NaN gradients, which would poison the optimisation below.
    log_likelihood = jnp.sum(-jnp.logaddexp(0., -signed_margin))
    return log_prior + log_likelihood


# Autobatched version: vmap applies `log_joint` to each row of a batch of
# coefficient vectors.
batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient#
@jax.jit
def elbo(beta_loc, beta_log_scale, epsilon):
    """Estimate the variational objective from a batch of noise draws.

    Args:
        beta_loc: location parameters of the variational distribution.
        beta_log_scale: log of its per-coordinate scale parameters.
        epsilon: batch of noise draws; each row is turned into one sample
            `beta_loc + exp(beta_log_scale) * row`.

    Returns:
        The mean over the batch (axis 0) of `batched_log_joint` at the samples,
        plus a term that depends only on `beta_log_scale`.
    """
    beta_scale = jnp.exp(beta_log_scale)
    beta_sample = beta_loc + beta_scale * epsilon
    expected_log_joint = jnp.mean(batched_log_joint(beta_sample), axis=0)
    # NOTE(review): the entropy of a diagonal Gaussian is
    # sum(log_scale) + 0.5 * D * log(2*pi*e); the constant used here differs
    # from that. Constants do not change the gradients, only the reported
    # value -- confirm whether the printed number is meant to be the exact ELBO.
    scale_term = jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
    return expected_log_joint + scale_term


# Value and gradients with respect to (beta_loc, beta_log_scale).
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD#
# `shape` (positional argument 1) is marked static for jit.
@functools.partial(jax.jit, static_argnums=(1,))
def normal_sample(key, shape):
    """Split `key`, then draw standard-normal noise of the given shape.

    Gives a quasi-stateful RNG interface: callers thread the returned key
    into the next call.

    Args:
        key: JAX PRNG key to split.
        shape: shape of the sample to draw.

    Returns:
        A pair `(new_key, sample)`: the first half of the split, and a normal
        draw generated from the second half.
    """
    carry_key, draw_key = random.split(key)
    return carry_key, random.normal(draw_key, shape)
# Fit the variational parameters by stepping along the ELBO gradient.
key = random.PRNGKey(10003)

# Variational parameters, both initialised to zero.
beta_loc = jnp.zeros((num_features,), dtype=jnp.float32)
beta_log_scale = jnp.zeros((num_features,), dtype=jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)

for i in range(1000):
    # Fresh noise for this step; the key is threaded through every iteration.
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon
    )
    # Ascent step: the gradients are added, since the ELBO is being maximised.
    beta_loc = beta_loc + step_size * beta_loc_grad
    beta_log_scale = beta_log_scale + step_size * beta_log_scale_grad
    # Report the objective every tenth step.
    if not i % 10:
        print(f'{i}\t{elbo_val}')
0 -180.85391235351562
10 -113.06047058105469
20 -102.73725891113281
30 -99.78732299804688
40 -98.90898895263672
50 -98.29743957519531
60 -98.18630981445312
70 -97.5797348022461
80 -97.28599548339844
90 -97.46998596191406
100 -97.47715759277344
110 -97.5806884765625
120 -97.49433898925781
130 -97.50270080566406
140 -96.86398315429688
150 -97.44197082519531
160 -97.06938934326172
170 -96.84031677246094
180 -97.21339416503906
190 -97.56500244140625
200 -97.26395416259766
210 -97.11984252929688
220 -97.39595794677734
230 -97.16830444335938
240 -97.11840057373047
250 -97.24346160888672
260 -97.29786682128906
270 -96.69286346435547
280 -96.96443176269531
290 -97.3005599975586
300 -96.63589477539062
310 -97.0351791381836
320 -97.52906799316406
330 -97.2880630493164
340 -97.07324981689453
350 -97.15620422363281
360 -97.25880432128906
370 -97.19515228271484
380 -97.13092803955078
390 -97.11730194091797
400 -96.93872833251953
410 -97.26676940917969
420 -97.35321044921875
430 -97.2100830078125
440 -97.28434753417969
450 -97.16310119628906
460 -97.26123809814453
470 -97.21342468261719
480 -97.23995971679688
490 -97.1491470336914
500 -97.23527526855469
510 -96.93415832519531
520 -97.21208190917969
530 -96.82577514648438
540 -97.01283264160156
550 -96.9417724609375
560 -97.16526794433594
570 -97.29165649414062
580 -97.42940521240234
590 -97.24371337890625
600 -97.15219116210938
610 -97.4984359741211
620 -96.99072265625
630 -96.88955688476562
640 -96.89968872070312
650 -97.13794708251953
660 -97.43705749511719
670 -96.99232482910156
680 -97.15624237060547
690 -97.1869125366211
700 -97.1115951538086
710 -97.78104400634766
720 -97.23224639892578
730 -97.16204071044922
740 -96.99580383300781
750 -96.66720581054688
760 -97.16795349121094
770 -97.51432037353516
780 -97.28899383544922
790 -96.91226959228516
800 -97.17100524902344
810 -97.29046630859375
820 -97.16242980957031
830 -97.19109344482422
840 -97.5638427734375
850 -97.00192260742188
860 -96.86555480957031
870 -96.76338195800781
880 -96.83660125732422
890 -97.121826171875
900 -97.09553527832031
910 -97.06825256347656
920 -97.1194839477539
930 -96.87931823730469
940 -97.45622253417969
950 -96.69277954101562
960 -97.29376983642578
970 -97.33528137207031
980 -97.349609375
990 -97.09675598144531
Display the results#
Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
# Compare the fitted posterior with the coefficients used to generate the data:
# blue dots are posterior means, red dots are mean +/- 2 posterior scales, and
# the black diagonal marks where estimate == truth.
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
# Raw string: `\s` is not a valid escape sequence in a normal string literal,
# which triggers a warning on recent Python versions. The label text itself
# is unchanged.
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.',
     label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
<matplotlib.legend.Legend at 0x7f90aed84860>
