description: Hamiltonian Monte Carlo, with given momentum distribution.
Hamiltonian Monte Carlo, with given momentum distribution.
Inherits From: HamiltonianMonteCarlo, TransitionKernel
tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn, step_size, num_leapfrog_steps, momentum_distribution=None,
    state_gradients_are_stopped=False, step_size_update_fn=None,
    store_parameters_in_results=False, experimental_shard_axis_names=None, name=None
)
See tfp.mcmc.HamiltonianMonteCarlo for details on HMC.
HMC produces samples much more efficiently if properly preconditioned. This can be done by choosing a momentum distribution with covariance equal to the inverse of the state's covariance.
The following example uses an estimate of the target covariance to precondition HMC:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions

# Suppose we have a target log prob fn, as well as an estimate of its
# covariance.
log_prob_fn = ...
cov_estimate = ...

# We want the mass matrix (the momentum covariance) to be the *inverse* of the
# covariance estimate, so the estimate itself is the momentum precision and
# its Cholesky factor serves as the precision factor:
momentum_distribution = (
    tfde.MultivariateNormalPrecisionFactorLinearOperator(
        precision_factor=tf.linalg.LinearOperatorLowerTriangular(
            tf.linalg.cholesky(cov_estimate),
        ),
        precision=tf.linalg.LinearOperatorFullMatrix(cov_estimate),
    )
)
# Run standard HMC below
num_burnin_steps = 100
num_results = 1000
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
        target_log_prob_fn=log_prob_fn,
        momentum_distribution=momentum_distribution,
        step_size=0.3,
        num_leapfrog_steps=10),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=tf.zeros(3),  # 3 chains.
      kernel=adaptive_hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws, cross_chain_dims=1)

run_chain_and_compute_ess()  # Something close to 3 x 1000.
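To see why the Cholesky factor of the covariance estimate works as a precision factor, the relationship can be checked numerically. The following is a NumPy sketch rather than TFP code, and the 2x2 matrix is a made-up value for illustration. It assumes momentum is produced by solving against the transposed factor, which yields a covariance equal to the inverse of the estimate.

```python
import numpy as np

# Hypothetical covariance estimate of the target, for illustration only.
cov_estimate = np.array([[4.0, 1.2],
                         [1.2, 1.0]])

# Lower-triangular Cholesky factor: chol @ chol.T == cov_estimate.
chol = np.linalg.cholesky(cov_estimate)

# Treating chol as a precision factor makes cov_estimate the momentum precision.
momentum_precision = chol @ chol.T

# If z ~ N(0, I), then p = solve(chol.T, z) has covariance
# inv(chol.T) @ inv(chol.T).T, which equals inv(cov_estimate).
transform = np.linalg.inv(chol.T)
momentum_cov = transform @ transform.T

print(np.allclose(momentum_precision, cov_estimate))
print(np.allclose(momentum_cov, np.linalg.inv(cov_estimate)))
```

Both checks compare matrices up to floating-point tolerance, so they confirm the algebra but say nothing about how well a particular estimate matches the true target covariance.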
The next example demonstrates multiple state parts, and reshaping a tfde.MultivariateNormalPrecisionFactorLinearOperator for use with a scalar or other non-vector event shape (in this case, [2, 3, 4]).
# Assumes tfd = tfp.distributions and tfde = tfp.experimental.distributions.
mvn = tfd.JointDistributionSequential([
    tfd.Normal(0., 0.1),
    tfd.Normal(0., 10.),
    tfd.Independent(tfd.Normal(tf.fill([2, 3, 4], 3.), 10.),
                    reinterpreted_batch_ndims=3)])

# Unused below: the scalar parts use tfd.Normal directly.
reshape_to_scalar = tfp.bijectors.Reshape(event_shape_out=[])
reshape_to_234 = tfp.bijectors.Reshape(event_shape_out=[2, 3, 4])
# Momentum scales are the reciprocals of the target scales.
momentum_distribution = tfd.JointDistributionSequential([
    tfd.Normal(0., 10.),
    tfd.Normal(0., 0.1),
    reshape_to_234(
        tfde.MultivariateNormalPrecisionFactorLinearOperator(
            0., tf.linalg.LinearOperatorDiag(tf.fill([24], 10.))))
])

num_burnin_steps = 100
num_results = 1000
# No step size adaptation is applied in this example.
hmc = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=mvn.log_prob,
    momentum_distribution=momentum_distribution,
    step_size=0.3,
    num_leapfrog_steps=10)

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=mvn.sample(),
      kernel=hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws)

run_chain_and_compute_ess()  # [1000, 1000, 1000 * tf.ones([2, 3, 4])]
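The momentum scales in the example above are the reciprocals of the target's standard deviations. The bookkeeping can be sketched with plain NumPy arrays instead of TFP distributions; the variable names here are illustrative and not part of any API.

```python
import numpy as np

# Standard deviations of the three target parts in the example above.
target_scales = [np.float64(0.1), np.float64(10.0), np.full([2, 3, 4], 10.0)]

# Preconditioning uses momentum covariance = inverse target covariance,
# so each momentum standard deviation is the reciprocal of the target's.
momentum_scales = [1.0 / s for s in target_scales]

# The third part's precision factor is built as a flat diagonal of length 24
# (precision factor = 1 / momentum scale) and then viewed with shape [2, 3, 4].
flat_precision_factor = np.full([24], 10.0)
reshaped_precision_factor = flat_precision_factor.reshape([2, 3, 4])

for scale in momentum_scales:
  print(np.shape(scale), np.ravel(scale)[0])
```

The reshape only changes how the 24 independent components are indexed, which is why a diagonal operator of size 24 is enough for a [2, 3, 4] state part.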
Args | |
---|---|
target_log_prob_fn | Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. |
step_size | Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable. |
num_leapfrog_steps | Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps. |
momentum_distribution | A tfp.distributions.Distribution instance to draw momentum from. Defaults to normal distributions with identity covariance. |
state_gradients_are_stopped | Python bool indicating whether the proposed new state should be run through tf.stop_gradient. This is particularly useful when combining optimization over samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient). |
step_size_update_fn | Python callable taking the current step_size (typically a tf.Variable) and kernel_results (typically a collections.namedtuple) and returning the updated step_size (Tensors). Default value: None (i.e., do not update step_size automatically). |
store_parameters_in_results | If True, then step_size, momentum_distribution, and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. If this is True, momentum_distribution must be a CompositeTensor; see tfp.experimental.as_composite and tfp.experimental.auto_composite. This is incompatible with step_size_update_fn, which must be set to None. |
experimental_shard_axis_names | A structure of string names indicating how members of the state are sharded. |
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'phmc_kernel'). |
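How step_size, num_leapfrog_steps, and the momentum covariance interact can be sketched with a plain NumPy leapfrog integrator. This is an illustrative sketch, not TFP's implementation: the leapfrog function, its inv_mass argument, and the 2x2 Gaussian target are assumptions made for the example.

```python
import numpy as np


def leapfrog(position, momentum, grad_log_prob, step_size,
             num_leapfrog_steps, inv_mass):
  """Integrates Hamiltonian dynamics with a fixed inverse mass matrix."""
  q = np.array(position, dtype=np.float64)
  p = np.array(momentum, dtype=np.float64)
  p = p + 0.5 * step_size * grad_log_prob(q)   # Initial half momentum step.
  for step in range(num_leapfrog_steps):
    q = q + step_size * (inv_mass @ p)         # Full position step.
    if step < num_leapfrog_steps - 1:
      p = p + step_size * grad_log_prob(q)     # Full momentum step.
  p = p + 0.5 * step_size * grad_log_prob(q)   # Final half momentum step.
  return q, p


# Hypothetical zero-mean Gaussian target.
cov = np.array([[4.0, 1.2],
                [1.2, 1.0]])
precision = np.linalg.inv(cov)


def grad_log_prob(q):
  # Gradient of the zero-mean Gaussian log density.
  return -precision @ q


def energy(q, p, inv_mass):
  # Potential energy plus kinetic energy.
  return 0.5 * q @ precision @ q + 0.5 * p @ inv_mass @ p


q0 = np.array([1.0, -0.5])
p0 = np.array([0.3, 0.8])

# Preconditioned: mass matrix = inverse covariance, so inv_mass = cov.
q1, p1 = leapfrog(q0, p0, grad_log_prob, step_size=0.3,
                  num_leapfrog_steps=10, inv_mass=cov)
print(energy(q1, p1, cov) - energy(q0, p0, cov))
```

The printed energy change is what a Metropolis correction would act on: the closer it is to zero, the more likely the proposal is accepted. Passing np.eye(2) as inv_mass instead corresponds to the default identity-covariance momentum.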
Attributes | |
---|---|
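experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | |
num_leapfrog_steps | Returns the num_leapfrog_steps parameter. If store_parameters_in_results was set to True, this returns only the value placed in the kernel results by bootstrap_results; the value actually used is governed by the previous_kernel_results argument to one_step. |
parameters | Return dict of __init__ arguments and their values. |
state_gradients_are_stopped | |
step_size | Returns the step_size parameter. If store_parameters_in_results was set to True, this returns only the value placed in the kernel results by bootstrap_results; the value actually used is governed by the previous_kernel_results argument to one_step. |
step_size_update_fn | |
target_log_prob_fn | |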
bootstrap_results
bootstrap_results(
    init_state
)
Creates initial previous_kernel_results using a supplied state.
copy
copy(
    **override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
new_kernel | TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
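The override rule described for copy can be illustrated with a toy class. ToyKernel below is a hypothetical stand-in that models only the parameter handling; it is not a TFP kernel and does no sampling.

```python
# Hypothetical stand-in for a kernel; only the parameter handling is modeled.
class ToyKernel:

  def __init__(self, step_size, num_leapfrog_steps, name=None):
    self._parameters = dict(
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        name=name)

  @property
  def parameters(self):
    return self._parameters

  def copy(self, **override_parameter_kwargs):
    # Overrides win on shared keys; everything else is carried over.
    parameters = dict(self.parameters, **override_parameter_kwargs)
    return type(self)(**parameters)


kernel = ToyKernel(step_size=0.3, num_leapfrog_steps=10)
new_kernel = kernel.copy(step_size=0.1)
print(new_kernel.parameters)
print(kernel.parameters)
```

The original kernel keeps its own parameters, which is the sense in which the copy is non-destructive.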
experimental_with_shard_axes
experimental_with_shard_axes(
    shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step
one_step(
    current_state, previous_kernel_results, seed=None
)
Runs one iteration of Hamiltonian Monte Carlo.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
Raises | |
---|---|
ValueError | If step_size is neither a single value nor a list with the same length as current_state. |
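The chain-dimension rule for current_state and the length rule for step_size can be illustrated without TFP. In the NumPy sketch below, target_log_prob and check_step_sizes are hypothetical stand-ins written for this example, not library functions, and the error message is invented.

```python
import numpy as np


def target_log_prob(x):
  # Unnormalized standard normal log density, summed over the event axis.
  return -0.5 * np.sum(x**2, axis=-1)


# Four chains, each with a 3-dimensional event.
current_state = np.zeros([4, 3])

log_prob = target_log_prob(current_state)
num_chain_dims = log_prob.ndim  # r in the description of current_state.
chain_shape = current_state.shape[:num_chain_dims]
event_shape = current_state.shape[num_chain_dims:]
print(num_chain_dims, chain_shape, event_shape)


def check_step_sizes(step_sizes, state_parts):
  # Mirrors the documented rule: one step size, or one per state part.
  if len(step_sizes) not in (1, len(state_parts)):
    raise ValueError('Expected one step size or one per state part.')
  return step_sizes


check_step_sizes([0.3], [current_state])
```

Because the log density reduces only the last axis, the leading axis is read as the chain axis; a target that reduced over both axes would instead describe a single chain with a [4, 3] event.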