Sample callback
This notebook demonstrates the use of the callback argument of pm.sample. A callback is a function that gets called for every sample drawn by a chain. It is called with the trace and the current draw as arguments, and the trace contains all samples drawn so far for a single chain.
The sampling process can be interrupted by raising a KeyboardInterrupt from inside the callback.
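To make the protocol concrete, the toy loop below imitates it in plain Python. It is not PyMC3's implementation: toy_sample and the Draw tuple are hypothetical stand-ins, and only the behaviour described above is modelled (the callback sees the growing trace of one chain, and a KeyboardInterrupt raised inside it ends sampling early while keeping the samples collected so far).

```python
from collections import namedtuple

# Hypothetical stand-in for the draw object pm.sample passes to the callback;
# only a few fields are modelled here.
Draw = namedtuple("Draw", ["chain", "draw_idx", "tuning"])


def toy_sample(draws, callback, chain=0):
    """Toy single-chain sampling loop (not PyMC3 code).

    After every draw, the growing trace and the current draw are passed to
    the callback. A KeyboardInterrupt raised there stops the loop early, and
    the samples collected so far are returned.
    """
    trace = []
    try:
        for i in range(draws):
            trace.append(float(i))  # placeholder for a real posterior sample
            callback(trace, Draw(chain=chain, draw_idx=i, tuning=False))
    except KeyboardInterrupt:
        pass  # stop early but keep what has been sampled
    return trace


def stop_at_100(trace, draw):
    # Same rule as the first example in this notebook.
    if len(trace) >= 100:
        raise KeyboardInterrupt()


print(len(toy_sample(500, stop_at_100)))  # 100
```

The first real example further down applies the same rule to an actual PyMC3 model.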
Use cases for this callback include:
Stopping sampling when a number of effective samples is reached
Stopping sampling when there are too many divergences
Logging metrics to external tools (such as TensorBoard)
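As a sketch of the second use case, the hypothetical class below stops sampling once any chain has produced more than a given number of divergences after tuning. It assumes that draw.stats is a list holding one dictionary of sampler statistics per step method, and that NUTS reports a boolean "diverging" entry there; check this against your PyMC3 version. To keep the snippet runnable without PyMC3, a hypothetical Draw tuple stands in for the real draw object.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-in for the draw object. ASSUMPTION: draw.stats is a list
# with one dict of sampler statistics per step method, and NUTS reports a
# boolean "diverging" entry in it.
Draw = namedtuple("Draw", ["chain", "tuning", "stats"])


class StopOnDivergences:
    """Raise KeyboardInterrupt once any chain exceeds max_divergences
    post-tuning divergences."""

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.counts = defaultdict(int)  # chain id -> divergences seen so far

    def __call__(self, trace, draw):
        if draw.tuning:
            return  # ignore divergences that happen during tuning
        if any(s.get("diverging", False) for s in draw.stats):
            self.counts[draw.chain] += 1
            if self.counts[draw.chain] > self.max_divergences:
                raise KeyboardInterrupt()


# Simulated run: every other draw diverges, so the third divergence
# (at index 4) exceeds the limit of 2 and stops "sampling".
cb = StopOnDivergences(max_divergences=2)
stopped_at = None
try:
    for i in range(10):
        cb([], Draw(chain=0, tuning=False, stats=[{"diverging": i % 2 == 0}]))
except KeyboardInterrupt:
    stopped_at = i
print(stopped_at)  # 4
```

With a real model, an instance would be passed as pm.sample(..., callback=StopOnDivergences(10)). Counting per chain keeps one badly behaved chain from being masked by the others.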
We’ll start by defining a simple linear regression model:
import numpy as np
import pymc3 as pm
X = np.array([1, 2, 3, 4, 5])
y = X * 2 + np.random.randn(len(X))
with pm.Model() as model:
    intercept = pm.Normal("intercept", 0, 10)
    slope = pm.Normal("slope", 0, 10)

    mean = intercept + slope * X
    error = pm.HalfCauchy("error", 1)
    obs = pm.Normal("obs", mean, error, observed=y)
For example, we can add a callback that stops sampling once 100 samples have been drawn, regardless of the number of draws set in pm.sample:
def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()

with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)
print(len(trace))
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [error, slope, intercept]
There were 12 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5303940121554945, but should be close to 0.8. Try to increase the number of tuning steps.
Only one chain was sampled, this makes it impossible to run some convergence checks
100
Note, however, that the trace passed to the callback only corresponds to a single chain. This means that if we want to do calculations over multiple chains at once, we need a bit of extra machinery.
def my_callback(trace, draw):
    if len(trace) % 100 == 0:
        print(len(trace))

with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [error, slope, intercept]
100
200
100
300
400
200
500
300
400
500
The chain contains only diverging samples. The model is probably misspecified.
The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.
There were 18 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 9.211751427765233e-155, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
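The interleaved output above shows that callbacks from different chains arrive in no particular order. A minimal version of the extra machinery is a dictionary keyed by the chain index. The hypothetical class below counts post-tuning draws per chain and stops once every chain has at least n_draws of them. It assumes the draw object exposes chain and tuning attributes, as the next example uses; the Draw tuple is again a stand-in so the sketch runs without PyMC3. Counting callbacks instead of reading len(trace) avoids any assumption about whether the trace also stores tuning samples.

```python
from collections import namedtuple

# Hypothetical stand-in for the draw object; only .chain and .tuning are used.
Draw = namedtuple("Draw", ["chain", "tuning"])


class StopWhenAllChainsReach:
    """Stop once each of n_chains chains has produced n_draws post-tuning
    draws, no matter in which order the chains report."""

    def __init__(self, n_chains, n_draws):
        self.n_chains = n_chains
        self.n_draws = n_draws
        self.counts = {}  # chain id -> post-tuning draws seen so far

    def __call__(self, trace, draw):
        if draw.tuning:
            return
        self.counts[draw.chain] = self.counts.get(draw.chain, 0) + 1
        if (len(self.counts) == self.n_chains
                and min(self.counts.values()) >= self.n_draws):
            raise KeyboardInterrupt()


# Simulated interleaved arrival order from two chains.
cb = StopWhenAllChainsReach(n_chains=2, n_draws=3)
order = [0, 0, 1, 0, 0, 1, 1]
seen = []
try:
    for chain in order:
        seen.append(chain)
        cb(None, Draw(chain=chain, tuning=False))
except KeyboardInterrupt:
    pass
print(len(seen))  # 7: chain 1 is the last to reach 3 draws
```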
We can use the draw.chain attribute to figure out which chain the current draw and trace belong to. Combined with a convergence statistic such as rhat, this lets us stop sampling once the chains have converged, regardless of the number of draws specified.
import arviz as az


class MyCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every        # check convergence every `every` draws
        self.max_rhat = max_rhat  # stop once the largest rhat is below this
        self.traces = {}          # chain id -> latest trace of that chain

    def __call__(self, trace, draw):
        # Ignore draws made during tuning.
        if draw.tuning:
            return

        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            # Combine the per-chain traces so rhat can compare chains.
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt


with model:
    trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [error, slope, intercept]
The estimated number of effective samples is smaller than 200 for some parameters.
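For intuition about the stopping rule in MyCallback, the sketch below computes the classic Gelman-Rubin R-hat for one scalar parameter in plain Python. It compares the variance between chain means with the average within-chain variance, and it approaches 1 when the chains agree. This is a simplified, swapped-in version: pm.stats.rhat delegates to ArviZ, which in recent versions defaults to a rank-normalized split R-hat, so its values will not match this one exactly. gelman_rubin_rhat is a hypothetical helper, not a PyMC3 or ArviZ function.

```python
import math
import random
from statistics import mean, variance


def gelman_rubin_rhat(chains):
    """Classic (non-split) Gelman-Rubin potential scale reduction factor.

    `chains` is a list of equal-length lists of samples of one parameter.
    """
    m, n = len(chains), len(chains[0])
    if m < 2 or n < 2 or any(len(c) != n for c in chains):
        raise ValueError("need at least 2 chains of equal length >= 2")
    chain_means = [mean(c) for c in chains]
    within = mean(variance(c) for c in chains)       # W: mean within-chain variance
    between = n * variance(chain_means)              # B: n/(m-1) * sum of squared mean deviations
    pooled = (n - 1) / n * within + between / n      # pooled posterior variance estimate
    return math.sqrt(pooled / within)


random.seed(0)
# Two chains sampling the same distribution: R-hat should be near 1.
mixed = [[random.gauss(0, 1) for _ in range(1000)] for _ in range(2)]
# Two chains stuck around different values: R-hat well above 1.05.
stuck = [[random.gauss(0, 1) for _ in range(1000)],
         [random.gauss(5, 1) for _ in range(1000)]]
print(gelman_rubin_rhat(mixed))
print(gelman_rubin_rhat(stuck))
```

A threshold such as the max_rhat=1.05 used above turns this statistic into a stopping rule: keep sampling while the chains still disagree noticeably.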
%load_ext watermark
%watermark -n -u -v -iv -w
pymc3 3.8
arviz 0.7.0
pandas 0.25.3
seaborn 0.9.0
numpy 1.17.5
last updated: Wed Apr 22 2020
CPython 3.8.0
IPython 7.11.0
watermark 2.0.2