Use gradient ascent to adapt the inner kernel's trajectory length.
Inherits From: TransitionKernel
tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    inner_kernel,
    num_adaptation_steps,
    use_halton_sequence_jitter=True,
    adaptation_rate=0.025,
    jitter_amount=1.0,
    criterion_fn=tfp.experimental.mcmc.chees_criterion,
    max_leapfrog_steps=1000,
    num_leapfrog_steps_getter_fn=hmc_like_num_leapfrog_steps_getter_fn,
    num_leapfrog_steps_setter_fn=hmc_like_num_leapfrog_steps_setter_fn,
    step_size_getter_fn=hmc_like_step_size_getter_fn,
    proposed_velocity_getter_fn=hmc_like_proposed_velocity_getter_fn,
    log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn,
    proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
    validate_args=False,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None,
    name=None
)
This kernel optimizes the continuous trajectory length (aka integration time) parameter of Hamiltonian Monte Carlo. It does so by following the gradient of a criterion with respect to the trajectory length. The criterion is computed via criterion_fn, with signature (previous_state, proposed_state, accept_prob) -> criterion, where the returned criterion retains the batch dimensions implied by its three inputs. See chees_criterion for an example.
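To illustrate the expected shapes, here is a simplified ChEES-like criterion written in NumPy. This is a sketch for intuition only, not the library's chees_criterion implementation; the function name and the exact weighting are illustrative assumptions.

```python
import numpy as np

def chees_like_criterion(previous_state, proposed_state, accept_prob):
    # Center each state across the leading chain (batch) dimension.
    prev_c = previous_state - previous_state.mean(axis=0)
    prop_c = proposed_state - proposed_state.mean(axis=0)
    # Half the change in squared distance from the cross-chain mean.
    delta = 0.5 * ((prop_c**2).sum(axis=-1) - (prev_c**2).sum(axis=-1))
    # Weight by acceptance probability; the chain batch dimension is retained.
    return accept_prob * delta**2

rng = np.random.default_rng(0)
previous_state = rng.normal(size=(8, 3))   # 8 chains, 3-dimensional state
proposed_state = rng.normal(size=(8, 3))
accept_prob = np.full(8, 0.7)
criterion = chees_like_criterion(previous_state, proposed_state, accept_prob)
print(criterion.shape)  # (8,): one criterion value per chain
```

Note how the output keeps the per-chain batch dimension, which is what allows the adaptation to average criterion gradients across chains.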
To avoid resonances, this kernel jitters the integration time between 0 and the learned trajectory length by default.
The initial trajectory length is extracted from the inner HamiltonianMonteCarlo kernel by multiplying the initial step size by the initial number of leapfrog steps. This (and other algorithmic details) implies that the step size must be a scalar.
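The discrete step count the inner kernel actually takes can be thought of as the continuous trajectory length divided by the step size, clipped by max_leapfrog_steps. The rounding rule below is a simplifying assumption for illustration, not the exact library behavior:

```python
import math

def leapfrog_steps_for(trajectory_length, step_size, max_leapfrog_steps=1000):
    # Hypothetical rounding rule: at least one step, at most
    # max_leapfrog_steps, otherwise ceil(trajectory_length / step_size).
    return min(max(1, math.ceil(trajectory_length / step_size)),
               max_leapfrog_steps)

# Initial trajectory length = initial step size * initial leapfrog steps.
print(leapfrog_steps_for(0.5 * 1, 0.5))  # 1
print(leapfrog_steps_for(5.0, 0.5))      # 10
print(leapfrog_steps_for(600.0, 0.5))    # 1000 (clipped from 1200)
```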
In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires that num_adaptation_steps be set to a value somewhat smaller than the number of burnin steps. However, it may sometimes be helpful to set num_adaptation_steps to a larger value during development in order to inspect the behavior of the chain during adaptation.
This implements something similar to ChEES HMC from [2].
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

target_log_prob_fn = tfd.JointDistributionSequential([
    tfd.Normal(0., 20.),
    tfd.HalfNormal(10.),
]).log_prob

num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
num_chains = 16
step_size = 0.1

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=step_size,
    num_leapfrog_steps=1,
)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    kernel,
    num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel,
    [tfb.Identity(),
     tfb.Exp()])

def trace_fn(_, pkr):
    return (
        pkr.inner_results.inner_results.inner_results.accepted_results
        .step_size,
        pkr.inner_results.inner_results.max_trajectory_length,
        pkr.inner_results.inner_results.inner_results.log_accept_ratio,
    )

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, max_trajectory_length, log_accept_ratio] = (
    tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros(num_chains),
                       tf.zeros(num_chains)],
        kernel=kernel,
        trace_fn=trace_fn))

# ~0.75
accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))
[2]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In preparation.
Args:

inner_kernel: TransitionKernel-like object.

num_adaptation_steps: Scalar int Tensor number of initial steps during which to adjust the trajectory length. This may be greater than, less than, or equal to the number of burnin steps.

use_halton_sequence_jitter: Python bool. Whether to use a Halton sequence for jittering the trajectory length. This makes the procedure more stable than sampling trajectory lengths from a uniform distribution.

adaptation_rate: Floating point scalar Tensor. How rapidly to adapt the trajectory length.

jitter_amount: Floating point scalar Tensor. How much to jitter the trajectory on the next step. The trajectory length is sampled from [(1 - jitter_amount) * max_trajectory_length, max_trajectory_length].
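For intuition, here is a sketch (not the library's implementation) of how jittered trajectory lengths could be drawn from [(1 - jitter_amount) * max_trajectory_length, max_trajectory_length] using a base-2 Halton (van der Corput) sequence:

```python
def halton(index, base=2):
    # Van der Corput / Halton low-discrepancy point in [0, 1).
    result, f = 0.0, 1.0
    while index > 0:
        f /= base
        result += f * (index % base)
        index //= base
    return result

max_trajectory_length = 2.0
jitter_amount = 1.0
# With jitter_amount=1.0 the lengths are stratified samples covering
# [0, max_trajectory_length]; a smaller jitter_amount narrows the interval.
lengths = [((1.0 - jitter_amount) + jitter_amount * halton(i + 1))
           * max_trajectory_length
           for i in range(4)]
print(lengths)  # [1.0, 0.5, 1.5, 0.25]
```

Unlike i.i.d. uniform draws, consecutive Halton points fill the interval evenly, which is what makes the jittering more stable.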
criterion_fn: Callable with signature (previous_state, proposed_state, accept_prob) -> criterion. Computes the criterion value.

max_leapfrog_steps: Int32 scalar Tensor. Clips the number of leapfrog steps to this value.

num_leapfrog_steps_getter_fn: A callable with the signature (kernel_results) -> num_leapfrog_steps where kernel_results are the results of the inner_kernel, and num_leapfrog_steps is a floating point Tensor.

num_leapfrog_steps_setter_fn: A callable with the signature (kernel_results, new_num_leapfrog_steps) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_num_leapfrog_steps is a scalar Tensor, and new_kernel_results is a copy of kernel_results with the number of leapfrog steps set.

step_size_getter_fn: A callable with the signature (kernel_results) -> step_size where kernel_results are the results of the inner_kernel, and step_size is a floating point Tensor.
proposed_velocity_getter_fn: A callable with the signature (kernel_results) -> proposed_velocity where kernel_results are the results of the inner_kernel, and proposed_velocity is a (possibly nested) floating point Tensor. The velocity is the derivative of the state with respect to the trajectory length.

log_accept_prob_getter_fn: A callable with the signature (kernel_results) -> log_accept_prob where kernel_results are the results of the inner_kernel, and log_accept_prob is a floating point Tensor. log_accept_prob has shape [C0, ..., Cb] with b > 0.

proposed_state_getter_fn: A callable with the signature (kernel_results) -> proposed_state where kernel_results are the results of the inner_kernel, and proposed_state is a (possibly nested) floating point Tensor.

validate_args: Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.

experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.

experimental_reduce_chain_axis_names: A string or list of string names indicating how batches of chains are sharded.
name: Python str name prefixed to Ops created by this class. Default: 'gradient_based_trajectory_length_adaptation'.
Raises:

ValueError: If inner_kernel contains a TransformedTransitionKernel in its hierarchy. If you need to use the TransformedTransitionKernel, place it above this kernel in the hierarchy (see the example in the class docstring).
Attributes:

experimental_reduce_chain_axis_names

experimental_shard_axis_names: The shard axis names for members of the state.

inner_kernel

is_calibrated: Returns True if the Markov chain converges to the specified distribution.

max_leapfrog_steps

name

num_adaptation_steps

parameters: Returns a dict of __init__ arguments and their values.

use_halton_sequence_jitter

validate_args
bootstrap_results
bootstrap_results(
init_state
)
Returns an object with the same type as returned by one_step(...)[1].
Args:

init_state: Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns:

kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
copy
copy(
**override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args:

**override_parameter_kwargs: Python String/value dictionary of initialization arguments to override with new values.

Returns:

new_kernel: TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
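The override semantics above amount to a plain dict merge, which can be seen directly (the parameter values below are hypothetical):

```python
# dict(self.parameters, **override_parameter_kwargs): the second argument
# wins on shared keys, everything else is carried over unchanged.
parameters = {'num_adaptation_steps': 800, 'adaptation_rate': 0.025}
override_parameter_kwargs = {'num_adaptation_steps': 500}
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters)  # {'num_adaptation_steps': 500, 'adaptation_rate': 0.025}
```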
criterion_fn
criterion_fn(
previous_state, proposed_state, accept_prob
)
experimental_with_shard_axes
experimental_with_shard_axes(
shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args:

shard_axis_names: A structure of strings indicating the shard axis names for each component of this kernel's state.

Returns:

A copy of the current kernel with the shard axis information.
log_accept_prob_getter_fn
log_accept_prob_getter_fn(
kernel_results
)
num_leapfrog_steps_getter_fn
num_leapfrog_steps_getter_fn(
kernel_results
)
num_leapfrog_steps_setter_fn
num_leapfrog_steps_setter_fn(
kernel_results, new_num_leapfrog_steps
)
one_step
one_step(
current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Must be overridden by subclasses.
Args:

current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).

previous_kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).

seed: PRNG seed; see tfp.random.sanitize_seed for details.

Returns:

next_state: Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).

kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
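To make the bootstrap_results / one_step contract concrete, here is a minimal toy kernel (an illustrative stand-in, not a real TFP kernel) driven the way tfp.mcmc.sample_chain drives any TransitionKernel: bootstrap once, then step repeatedly, threading kernel_results through each call.

```python
import random

class ToyRandomWalkKernel:
    """Illustrative stand-in obeying the TransitionKernel protocol."""

    def bootstrap_results(self, init_state):
        # Initial kernel_results, with the same structure one_step returns.
        return {'num_steps': 0}

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = random.Random(seed)
        next_state = current_state + rng.gauss(0.0, 1.0)
        kernel_results = {'num_steps': previous_kernel_results['num_steps'] + 1}
        return next_state, kernel_results

kernel = ToyRandomWalkKernel()
state = 0.0
results = kernel.bootstrap_results(state)  # called once, before stepping
for step in range(10):
    state, results = kernel.one_step(state, results, seed=step)
print(results['num_steps'])  # 10
```

Real kernels carry much richer kernel_results (step sizes, acceptance ratios, nested inner results), but the calling convention is exactly this loop.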
proposed_state_getter_fn
proposed_state_getter_fn(
kernel_results
)
proposed_velocity_getter_fn
proposed_velocity_getter_fn(
kernel_results
)
step_size_getter_fn
step_size_getter_fn(
kernel_results
)