description: Adapt and sample from a joint distribution using NUTS, conditioned on pins.
```python
tfp.experimental.mcmc.windowed_adaptive_nuts(
    n_draws, joint_dist, *, n_chains=64, num_adaptation_steps=500,
    current_state=None, init_step_size=None, dual_averaging_kwargs=None,
    max_tree_depth=10, max_energy_diff=500.0, unrolled_leapfrog_steps=1,
    parallel_iterations=10, trace_fn=_default_nuts_trace_fn,
    return_final_kernel_results=False, discard_tuning=True, chain_axis_names=None,
    seed=None, **pins
)
```
Step size is tuned using a dual-averaging adaptation, and the kernel is conditioned using a diagonal mass matrix, which is estimated using expanding windows.
Args:

* `n_draws`: `int`. Number of draws after adaptation.
* `joint_dist`: `tfd.JointDistribution`. A joint distribution to sample from.
* `n_chains`: `int` or list of `int`s. Number of independent chains to run MCMC with.
* `num_adaptation_steps`: `int`. Number of draws used to adapt step size and mass matrix.
* `current_state`: Optional. Structure of tensors at which to initialize sampling. Should have the same shape and structure as `model.experimental_pin(**pins).sample(n_chains)`.
* `init_step_size`: Optional. Where to initialize the step size for the leapfrog integrator. The structure should broadcast with `current_state`. For example, if the initial state is `{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`, then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work. Defaults to the dimension of the log density to the 0.25 power.
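The broadcasting requirement on `init_step_size` can be illustrated without TFP. This is a pure-NumPy sketch; `broadcasts_with_state` is a hypothetical helper for this example, not part of the library:

```python
import numpy as np

n_chains, n_features = 64, 3
# Hypothetical initial state matching the structure described above.
state = {'a': np.zeros(n_chains), 'b': np.zeros([n_chains, n_features])}

# Any of these step-size structures broadcasts against the state:
candidates = [
    1.0,                                       # one scalar for everything
    {'a': 1.0, 'b': 1.0},                      # a scalar per state part
    {'a': np.ones(n_chains),
     'b': np.ones([n_chains, n_features])},    # a full array per state part
]

def broadcasts_with_state(step_size, state):
    """Check that each step-size leaf broadcasts against its state part."""
    if not isinstance(step_size, dict):
        step_size = {k: step_size for k in state}
    for k, part in state.items():
        # Raises ValueError if the shapes are incompatible.
        np.broadcast_shapes(np.shape(step_size[k]), part.shape)
    return True

assert all(broadcasts_with_state(c, state) for c in candidates)
```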
* `dual_averaging_kwargs`: Optional `dict`. Keyword arguments to pass to `tfp.mcmc.DualAveragingStepSizeAdaptation`. By default, a `target_accept_prob` of 0.85 is set, acceptance probabilities across chains are reduced using a harmonic mean, and the class defaults are used otherwise.
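The harmonic-mean reduction is worth noting because it is dominated by the worst-performing chains, so step-size adaptation is conservative. A small NumPy illustration with made-up acceptance probabilities:

```python
import numpy as np

# Per-chain acceptance probabilities for one step (illustrative values).
accept_probs = np.array([0.9, 0.8, 0.5, 0.95])

# Harmonic mean vs. arithmetic mean across chains: the harmonic mean
# is pulled down by low-acceptance chains, so adapting against it
# shrinks the step size more aggressively when any chain struggles.
harmonic = len(accept_probs) / np.sum(1.0 / accept_probs)
arithmetic = accept_probs.mean()
assert harmonic < arithmetic
```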
* `max_tree_depth`: Maximum depth of the tree implicitly built by NUTS. The maximum number of leapfrog steps is bounded by `2**max_tree_depth`, i.e. the number of nodes in a binary tree `max_tree_depth` levels deep. The default setting of 10 takes up to 1024 leapfrog steps.
* `max_energy_diff`: Scalar threshold of energy differences at each leapfrog; divergent samples are defined as leapfrog steps that exceed this threshold. Defaults to 500.
* `unrolled_leapfrog_steps`: The number of leapfrogs to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by `max_tree_depth`. Defaults to 1.
* `parallel_iterations`: The number of iterations allowed to run in parallel. Must be a positive integer. See `tf.while_loop` for more details.
* `trace_fn`: Optional callable. The trace function should accept the arguments `(state, bijector, is_adapting, phmc_kernel_results)`, where `state` is an unconstrained, flattened float tensor, `bijector` is the `tfb.Bijector` used for unconstraining and flattening, `is_adapting` is a boolean marking whether the draw is from an adaptation step, and `phmc_kernel_results` is the `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the `PreconditionedHamiltonianMonteCarlo` kernel. Note that `bijector.inverse(state)` will provide access to the current draw in the untransformed space, using the structure of the provided `joint_dist`.
* `return_final_kernel_results`: If `True`, the final kernel results are returned alongside the chain state and the trace specified by `trace_fn`.
* `discard_tuning`: `bool`. Whether to discard the tuning traces and draws.
* `chain_axis_names`: A `str` or list of `str`s indicating the named axes by which multiple chains are sharded. See `tfp.experimental.mcmc.Sharded` for more context.
* `seed`: PRNG seed; see `tfp.random.sanitize_seed` for details.
* `**pins`: Used to condition the provided joint distribution; passed directly to `joint_dist.experimental_pin(**pins)`.
Returns:

If `trace_fn` is `None` and `return_final_kernel_results` is `False`, a single structure of draws is returned. If there is a trace function, the return value is a tuple, with the trace second. If `return_final_kernel_results` is `True`, the return value is a tuple of length 3, with the final kernel results returned last. If `discard_tuning` is `True`, the tensors in `draws` and `trace` will have length `n_draws`; otherwise, they will have length `n_draws + num_adaptation_steps`.