Autobatching for Bayesian Inference#


This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

Inspired by a notebook by @davmre.

from matplotlib.pyplot import *

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import numpy as np
import scipy as sp

Generate a fake binary classification dataset#

np.random.seed(10009)

num_features = 10
num_points = 100

true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
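
In other words, each label is drawn from the logistic-regression generative process

$$y_i \sim \mathrm{Bernoulli}\big(\sigma(x_i^\top \beta_{\text{true}})\big),$$

where $\sigma$ is the logistic sigmoid (scipy's expit).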

Write the log-joint function for the model#

We’ll write a non-batched version, a manually batched version, and an autobatched version.
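
Written out, the model places a standard normal prior on each coefficient and uses the Bernoulli likelihood above, so the log-joint density we need is

$$\log p(\beta, y) \;=\; \sum_{j} \log \mathcal{N}(\beta_j \mid 0, 1) \;+\; \sum_{i} \log \sigma\big((2 y_i - 1)\, x_i^\top \beta\big),$$

using the identity $\log \sigma(z) = -\log(1 + e^{-z})$, which is the form that appears in the code below.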

Non-batched#

def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result
log_joint(np.random.randn(num_features))
Array(-213.23558, dtype=float32)
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)

  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
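
To see where it breaks, we can inspect the intermediate shapes ourselves (a quick check added here, not part of the original example):

# With a batch of betas, the matrix product inside the likelihood term has
# shape (100, 10) ...
jnp.dot(all_x, batched_test_beta).shape   # (100, 10)
# ... which no longer broadcasts against the per-example labels of shape (100,),
# treated as (1, 100) by broadcasting -- hence the error message above.
(2 * y - 1).shape                         # (100,)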

Manually batched#

def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                              axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                              axis=-1)
    return result
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)
Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,
       -163.02911377, -143.84848022, -160.28771973, -113.77169037,
       -126.60544586, -190.81988525], dtype=float32)

Autobatched with vmap#

It just works.

vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,
       -163.02911377, -143.84848022, -160.28771973, -113.77169037,
       -126.60544586, -190.81988525], dtype=float32)
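
By default jax.vmap maps over the leading axis of every argument, which is exactly what we want here: all_x and y are closed over as constants, so only beta is batched. When some arguments should stay unbatched, in_axes makes that explicit. A small sketch (not part of the original notebook):

# Batch over beta (axis 0) while passing the data through unbatched.
batched_logits = jax.vmap(lambda beta, x: jnp.dot(x, beta), in_axes=(0, None))(
    batched_test_beta, all_x)
batched_logits.shape   # (10, 100): one row of logits per beta in the batch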

Self-contained variational inference example#

A little code is copied from above.

Set up the (batched) log-joint function#

@jax.jit
def log_joint(beta):
    result = 0.
    # As before, no `axis` parameter is provided to `jnp.sum`. Note also the
    # broader prior here (scale=10.) than in the demo above.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))

Define the ELBO and its gradient#
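
The variational family is a mean-field Gaussian, $q(\beta) = \mathcal{N}(\mu, \operatorname{diag}(e^{2s}))$, sampled via the reparameterization $\beta = \mu + e^{s} \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The Monte Carlo estimate of the ELBO computed below is

$$\widehat{\mathrm{ELBO}}(\mu, s) \;=\; \frac{1}{B} \sum_{b=1}^{B} \log p\big(y, \mu + e^{s} \odot \epsilon_b\big) \;+\; \sum_j \Big(s_j - \tfrac{1}{2}\log 2\pi\Big),$$

where the second term equals the entropy of $q$ up to an additive constant, which does not affect the gradients we optimize.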

def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
 
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))

Optimize the ELBO using SGD#

def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.PRNGKey(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-180.85391235351562
10	-113.06047058105469
20	-102.73725891113281
30	-99.78732299804688
40	-98.90898895263672
50	-98.29743957519531
60	-98.18630981445312
70	-97.5797348022461
80	-97.28599548339844
90	-97.46998596191406
100	-97.47715759277344
110	-97.5806884765625
120	-97.49433898925781
130	-97.50270080566406
140	-96.86398315429688
150	-97.44197082519531
160	-97.06938934326172
170	-96.84031677246094
180	-97.21339416503906
190	-97.56500244140625
200	-97.26395416259766
210	-97.11984252929688
220	-97.39595794677734
230	-97.16830444335938
240	-97.11840057373047
250	-97.24346160888672
260	-97.29786682128906
270	-96.69286346435547
280	-96.96443176269531
290	-97.3005599975586
300	-96.63589477539062
310	-97.0351791381836
320	-97.52906799316406
330	-97.2880630493164
340	-97.07324981689453
350	-97.15620422363281
360	-97.25880432128906
370	-97.19515228271484
380	-97.13092803955078
390	-97.11730194091797
400	-96.93872833251953
410	-97.26676940917969
420	-97.35321044921875
430	-97.2100830078125
440	-97.28434753417969
450	-97.16310119628906
460	-97.26123809814453
470	-97.21342468261719
480	-97.23995971679688
490	-97.1491470336914
500	-97.23527526855469
510	-96.93415832519531
520	-97.21208190917969
530	-96.82577514648438
540	-97.01283264160156
550	-96.9417724609375
560	-97.16526794433594
570	-97.29165649414062
580	-97.42940521240234
590	-97.24371337890625
600	-97.15219116210938
610	-97.4984359741211
620	-96.99072265625
630	-96.88955688476562
640	-96.89968872070312
650	-97.13794708251953
660	-97.43705749511719
670	-96.99232482910156
680	-97.15624237060547
690	-97.1869125366211
700	-97.1115951538086
710	-97.78104400634766
720	-97.23224639892578
730	-97.16204071044922
740	-96.99580383300781
750	-96.66720581054688
760	-97.16795349121094
770	-97.51432037353516
780	-97.28899383544922
790	-96.91226959228516
800	-97.17100524902344
810	-97.29046630859375
820	-97.16242980957031
830	-97.19109344482422
840	-97.5638427734375
850	-97.00192260742188
860	-96.86555480957031
870	-96.76338195800781
880	-96.83660125732422
890	-97.121826171875
900	-97.09553527832031
910	-97.06825256347656
920	-97.1194839477539
930	-96.87931823730469
940	-97.45622253417969
950	-96.69277954101562
960	-97.29376983642578
970	-97.33528137207031
980	-97.349609375
990	-97.09675598144531

Display the results#

Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
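
As a rough quantitative check (a small addition, not part of the original notebook), we can count how many of the true coefficients land inside their approximate 2-sigma credible intervals:

# Fraction of true coefficients covered by the +/- 2 sigma intervals.
lower = beta_loc - 2 * jnp.exp(beta_log_scale)
upper = beta_loc + 2 * jnp.exp(beta_log_scale)
coverage = float(jnp.mean((true_beta >= lower) & (true_beta <= upper)))
print(f'coverage of the 2-sigma intervals: {coverage:.0%}')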

figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
<matplotlib.legend.Legend at 0x7f90aed84860>
[Figure: estimated coefficients plotted against the true coefficients, with 2σ error bars and the y = x reference line]