Runs one step of Uncalibrated Langevin discretized diffusion.
Inherits From: TransitionKernel
tfp.substrates.jax.mcmc.UncalibratedLangevin(
target_log_prob_fn, step_size, volatility_fn=None, parallel_iterations=10,
compute_acceptance=True, experimental_shard_axis_names=None, name=None
)
The class generates a Langevin proposal using the _euler_method function and also computes the helper UncalibratedLangevinKernelResults for the next iteration.
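For intuition about what an Euler-discretized Langevin proposal computes, here is a minimal NumPy sketch for a single array-valued state with a constant volatility of 1. It assumes the textbook Euler-Maruyama form x' = x + (step_size / 2) * grad log p(x) + sqrt(step_size) * z with z standard normal. The names langevin_proposal and grad_log_prob are hypothetical, and the library's _euler_method may scale the drift and noise differently (for example, to account for volatility_fn).

```python
import numpy as np

def langevin_proposal(x, grad_log_prob, step_size, rng):
    """Hypothetical Euler-Maruyama step for dX = 0.5 * grad log p(X) dt + dW (unit volatility)."""
    drift = 0.5 * step_size * grad_log_prob(x)                       # pull toward higher density
    noise = np.sqrt(step_size) * rng.standard_normal(np.shape(x))   # Brownian increment
    return x + drift + noise

# Standard normal target: log p(x) = -x**2 / 2, so grad log p(x) = -x.
grad_log_prob = lambda x: -x

rng = np.random.default_rng(0)
x = np.full(100_000, 2.0)  # many copies of the same current state
proposals = langevin_proposal(x, grad_log_prob, step_size=0.1, rng=rng)
# Proposals are Gaussian with mean 2 + 0.05 * (-2) = 1.9 and variance step_size = 0.1.
print(proposals.mean(), proposals.var())
```

Iterating this step with no accept/reject test gives an uncalibrated chain in the sense used on this page.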
Warning: this kernel will not result in a chain that converges to the distribution defined by target_log_prob_fn. To get a convergent MCMC, use MetropolisAdjustedLangevinAlgorithm(...) or MetropolisHastings(UncalibratedLangevin(...)).
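The warning can be checked numerically. The NumPy sketch below runs many chains on a standard normal target with unit volatility, once without and once with a Metropolis-Hastings accept/reject test (the correction MALA adds). It assumes the Euler-Maruyama proposal x' = x + (h / 2) * grad log p(x) + sqrt(h) * z; the helpers log_q and run_chains are hypothetical and this is a hand-written illustration of the idea, not the library's MetropolisHastings implementation. For this linear-Gaussian case the uncalibrated chain's stationary variance works out to 1 / (1 - h/4) instead of 1.

```python
import numpy as np

def log_prob(x):
    return -0.5 * x**2  # standard normal target, up to an additive constant

def grad_log_prob(x):
    return -x

def log_q(y, x, h):
    """Log density (up to a constant) of proposing y from x: y ~ N(x + h/2 * grad log p(x), h)."""
    mean = x + 0.5 * h * grad_log_prob(x)
    return -0.5 * (y - mean) ** 2 / h

def run_chains(h, n_steps, adjust, rng, n_chains=1000):
    x = rng.standard_normal(n_chains)  # start every chain at an exact draw from the target
    draws = []
    for _ in range(n_steps):
        y = x + 0.5 * h * grad_log_prob(x) + np.sqrt(h) * rng.standard_normal(n_chains)
        if adjust:
            # Metropolis-Hastings test: accept y with probability min(1, ratio).
            log_ratio = log_prob(y) - log_prob(x) + log_q(x, y, h) - log_q(y, x, h)
            accept = np.log(rng.uniform(size=n_chains)) < log_ratio
            x = np.where(accept, y, x)
        else:
            x = y  # uncalibrated: the proposal is always taken
        draws.append(x)
    return np.concatenate(draws)

rng = np.random.default_rng(1)
h = 0.5
v_uncalibrated = run_chains(h, 2000, adjust=False, rng=rng).var()
v_adjusted = run_chains(h, 2000, adjust=True, rng=rng).var()
# Uncalibrated: x' = (1 - h/2) x + sqrt(h) z, so v = (1 - h/2)**2 v + h, i.e. v = 1 / (1 - h/4), about 1.143.
print(v_uncalibrated, v_adjusted)
```

Shrinking h reduces the uncalibrated bias but does not remove it, which is why a Metropolis-Hastings correction is needed for exact convergence.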
For more details on UncalibratedLangevin, see MetropolisAdjustedLangevinAlgorithm.
Args | |
---|---|
target_log_prob_fn | Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. |
step_size | Tensor or Python list of Tensors representing the step size for the Euler discretization of the Langevin diffusion. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes increase the discretization error and, once a Metropolis-Hastings correction is applied, make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable. |
volatility_fn | Python callable which takes an argument like current_state (or *current_state if it's a list) and returns the volatility value at current_state. Should return a Tensor or Python list of Tensors that must broadcast with the shape of current_state. Defaults to the identity function. |
parallel_iterations | The number of coordinates for which the gradients of the volatility matrix volatility_fn can be computed in parallel. |
compute_acceptance | Python bool indicating whether to compute the Metropolis log-acceptance ratio used to construct the MetropolisAdjustedLangevinAlgorithm kernel. |
experimental_shard_axis_names | A structure of string names indicating how members of the state are sharded. |
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mala_kernel'). |
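To show how a list-valued state, per-part step sizes, and a volatility_fn fit together under broadcasting, here is a NumPy sketch. It assumes a state-independent volatility v per part, i.e. the diffusion dX = 0.5 * v**2 * grad log p(X) dt + v dW, so no volatility gradients (the work parallel_iterations parallelizes) are needed. The helper euler_step_parts and its arguments are hypothetical; they mirror the argument rules above but do not reproduce the library's internal drift computation.

```python
import numpy as np

def euler_step_parts(state_parts, grad_fn, step_sizes, volatility_fn, rng):
    """Hypothetical Langevin Euler step for a list-valued state with constant volatility."""
    if not isinstance(step_sizes, (list, tuple)):
        step_sizes = [step_sizes] * len(state_parts)  # one step size shared by every part
    if len(step_sizes) != len(state_parts):
        raise ValueError("Expected one step_size or one per state part.")
    vols = volatility_fn(*state_parts)
    if not isinstance(vols, (list, tuple)):
        vols = [vols] * len(state_parts)  # one volatility shared by every part
    grads = grad_fn(*state_parts)
    next_parts = []
    for x, g, h, v in zip(state_parts, grads, step_sizes, vols):
        h = np.asarray(h, dtype=float)
        noise = rng.standard_normal(np.shape(x))
        # h and v broadcast against x, e.g. a per-column step size for a matrix-valued part.
        next_parts.append(x + 0.5 * h * v**2 * g + v * np.sqrt(h) * noise)
    return next_parts

# Independent standard normal target with two parts: a (3,) vector and a (2, 2) matrix.
def grad_fn(a, b):
    return [-a, -b]

state = [np.zeros(3), np.ones((2, 2))]
step_sizes = [0.1, np.array([0.05, 0.5])]  # second entry broadcasts over the matrix rows
unit_volatility = lambda *parts: 1.0

rng = np.random.default_rng(0)
new_state = euler_step_parts(state, grad_fn, step_sizes, unit_volatility, rng)
print([p.shape for p in new_state])  # [(3,), (2, 2)]

# Zero volatility scales both the drift and the noise by 0, so the state does not move.
frozen = euler_step_parts(state, grad_fn, step_sizes, lambda *parts: 0.0, rng)
```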
Raises | |
---|---|
ValueError | If there isn't one step_size or a list with the same length as current_state. |
TypeError | If volatility_fn is not callable. |
Attributes | |
---|---|
compute_acceptance | |
experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. Always False for this kernel. |
name | |
parallel_iterations | |
parameters | Returns a dict of __init__ arguments and their values. |
step_size | |
target_log_prob_fn | |
volatility_fn | |
bootstrap_results
bootstrap_results(init_state)
Creates initial previous_kernel_results using a supplied state.
copy
copy(**override_parameter_kwargs)
Non-destructively creates a deep copy of the kernel.

Args | |
---|---|
**override_parameter_kwargs | Python string/value dictionary of initialization arguments to override with new values. |

Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
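The merge rule dict(self.parameters, **override_parameter_kwargs) can be seen with plain Python dictionaries. The parameters dict below is a hypothetical stand-in for kernel.parameters; with a real kernel, kernel.copy(step_size=0.05) would build the new kernel from the merged dictionary while the original keeps its own values.

```python
# Hypothetical stand-in for kernel.parameters (the real dict holds the actual __init__ arguments).
parameters = {
    "target_log_prob_fn": None,
    "step_size": 0.1,
    "volatility_fn": None,
    "parallel_iterations": 10,
    "compute_acceptance": True,
    "name": None,
}
override_parameter_kwargs = {"step_size": 0.05}

# Shared keys take the override's value; every other key is carried over unchanged.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters["step_size"], parameters["step_size"])  # 0.05 0.1
```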
experimental_with_shard_axes
experimental_with_shard_axes(shard_axis_names)
Returns a copy of the kernel with the provided shard axis names.

Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |

Returns | |
---|---|
A copy of the current kernel with the shard axis information. | |
one_step
one_step(current_state, previous_kernel_results, seed=None)
Runs one iteration of uncalibrated Langevin dynamics: the proposed state is returned as the next state, with no Metropolis-Hastings accept/reject step.

Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |

Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |

Raises | |
---|---|
ValueError | If there isn't one step_size or a list with the same length as current_state or diffusion_drift. |
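The bootstrap_results / one_step protocol, and how the leading dimensions of current_state index independent chains, can be sketched with a toy NumPy class. Everything here is hypothetical: ToyUncalibratedLangevin and ToyResults are illustrative stand-ins (not the TFP class or its UncalibratedLangevinKernelResults), the seed is a NumPy integer seed rather than a TFP PRNG seed, and the proposal assumes the Euler-Maruyama form with unit volatility. In practice a driver such as sample_chain runs this loop for you.

```python
import collections
import numpy as np

# Hypothetical results structure; the library's kernel results have their own fields.
ToyResults = collections.namedtuple("ToyResults", ["target_log_prob", "grad_target_log_prob"])

class ToyUncalibratedLangevin:
    """NumPy sketch of the bootstrap_results / one_step protocol (not the TFP class)."""

    def __init__(self, target_log_prob_fn, grad_fn, step_size):
        self.target_log_prob_fn = target_log_prob_fn
        self.grad_fn = grad_fn
        self.step_size = step_size

    def bootstrap_results(self, init_state):
        # Cache log-prob and gradient so one_step need not recompute them at the current state.
        return ToyResults(self.target_log_prob_fn(init_state), self.grad_fn(init_state))

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = np.random.default_rng(seed)
        h = self.step_size
        noise = rng.standard_normal(current_state.shape)
        drift = 0.5 * h * previous_kernel_results.grad_target_log_prob
        next_state = current_state + drift + np.sqrt(h) * noise  # always accepted
        return next_state, self.bootstrap_results(next_state)

# Target: independent standard normals in 2 dimensions, evaluated for a batch of chains.
log_prob = lambda x: -0.5 * np.sum(x**2, axis=-1)  # shape (n_chains,), so r = 1
grad = lambda x: -x

kernel = ToyUncalibratedLangevin(log_prob, grad, step_size=0.1)
state = np.zeros((8, 2))  # 8 independent chains, each a 2-vector
results = kernel.bootstrap_results(state)
for step in range(100):
    state, results = kernel.one_step(state, results, seed=step)
print(state.shape, results.target_log_prob.shape)  # (8, 2) (8,)
```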