Registers a rule for substituting distributions in ASVI surrogates.
tfp.experimental.vi.register_asvi_substitution_rule(
    condition, substitution_fn
)
| Args | |
|---|---|
| `condition` | Python callable that takes a `Distribution` instance and returns a Python `bool` indicating whether or not to substitute it. May also be a class type such as `tfd.Normal`, in which case the condition is interpreted as `lambda distribution: isinstance(distribution, class)`. |
| `substitution_fn` | Python callable that takes a `Distribution` instance and returns a new `Distribution` instance used to define the ASVI surrogate posterior. Note that this substitution does not modify the original model. |
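The two accepted forms of `condition` can be illustrated with a small pure-Python sketch. It does not import TFP: `Normal`, `HalfCauchy`, and `Gamma` below are hypothetical stand-ins for `tfd` distributions, and `as_predicate` is a hypothetical helper showing how a class type reduces to an `isinstance` check while a callable is used as-is.

```python
import inspect
from typing import Any, Callable, Union


class Normal:
    """Hypothetical stand-in for tfd.Normal (not the real TFP class)."""

    def __init__(self, loc: float, scale: float):
        self.loc = loc
        self.scale = scale


class HalfCauchy:
    """Hypothetical stand-in for tfd.HalfCauchy."""

    def __init__(self, loc: float, scale: float):
        self.loc = loc
        self.scale = scale


class Gamma:
    """Hypothetical stand-in for a distribution with no loc/scale parameters."""

    def __init__(self, concentration: float, rate: float):
        self.concentration = concentration
        self.rate = rate


Condition = Union[type, Callable[[Any], bool]]


def as_predicate(condition: Condition) -> Callable[[Any], bool]:
    """Hypothetical helper: normalize a `condition` argument to a predicate.

    The class check comes first because classes are themselves callable:
    calling `Normal(d)` would try to construct an object, not return a bool.
    """
    if inspect.isclass(condition):
        return lambda distribution: isinstance(distribution, condition)
    return condition


is_normal = as_predicate(Normal)
is_loc_scale = as_predicate(
    lambda d: hasattr(d, "loc") and hasattr(d, "scale"))

print(is_normal(Normal(0.0, 1.0)), is_normal(HalfCauchy(0.0, 1.0)))  # True False
print(is_loc_scale(HalfCauchy(0.0, 1.0)), is_loc_scale(Gamma(2.0, 1.0)))  # True False
```

Passing a class is therefore just shorthand for the `isinstance` lambda; any richer test (attributes, names, parameter values) needs an explicit callable.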
To use a Normal surrogate for all location-scale family distributions, we could register the substitution:
import tensorflow_probability as tfp
tfd = tfp.distributions

tfp.experimental.vi.register_asvi_substitution_rule(
    condition=lambda distribution: (
        hasattr(distribution, 'loc') and hasattr(distribution, 'scale')),
    substitution_fn=lambda distribution: (
        # Invoking the event space bijector applies any relevant constraints,
        # e.g., that HalfCauchy samples must be `>= loc`.
        distribution.experimental_default_event_space_bijector()(
            tfd.Normal(loc=distribution.loc, scale=distribution.scale))))
This rule will fire when ASVI encounters a location-scale distribution,
and instructs ASVI to build a surrogate 'as if' the model had just used a
(possibly constrained) Normal in its place. Note that we could have used a
more precise condition, e.g., to limit the substitution to distributions with
a specific name, if we had reason to think that a Normal distribution would
be a good surrogate for some model variables but not others.
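The mechanics of a name-restricted rule can be sketched in pure Python without TFP. Everything here is hypothetical: `Dist` stands in for a model distribution, and `register_rule` / `substitute` mimic a rule registry. The sketch assumes rules are tried in registration order with the first match winning, which is an illustrative choice and not a statement about TFP's internal ordering; it also omits the event-space bijector step from the example above.

```python
import dataclasses
from typing import Any, Callable, List, Tuple


@dataclasses.dataclass(frozen=True)
class Dist:
    """Hypothetical stand-in for a location-scale distribution in a model."""
    family: str
    loc: float
    scale: float
    name: str


# Hypothetical registry of (condition, substitution_fn) pairs.
_RULES: List[Tuple[Callable[[Any], bool], Callable[[Any], Any]]] = []


def register_rule(condition: Callable[[Any], bool],
                  substitution_fn: Callable[[Any], Any]) -> None:
    """Hypothetical analogue of registering a substitution rule."""
    _RULES.append((condition, substitution_fn))


def substitute(distribution: Any) -> Any:
    """Return the distribution the surrogate is built 'as if' the model used.

    Assumption: rules are tried in registration order and the first match
    wins; a distribution matching no rule is returned unchanged.
    """
    for condition, substitution_fn in _RULES:
        if condition(distribution):
            return substitution_fn(distribution)
    return distribution


# Substitute a Normal only for the variables named 'x' and 'y'.
register_rule(
    condition=lambda d: (
        d.name in ("x", "y") and hasattr(d, "loc") and hasattr(d, "scale")),
    # dataclasses.replace returns a new object, so the model is not modified.
    substitution_fn=lambda d: dataclasses.replace(d, family="Normal"),
)

model = [Dist("Cauchy", 0.0, 1.0, "x"), Dist("Laplace", 0.0, 2.0, "z")]
surrogate_inputs = [substitute(d) for d in model]

print([d.family for d in surrogate_inputs])  # ['Normal', 'Laplace']
print([d.family for d in model])             # ['Cauchy', 'Laplace']
```

Only the variable named `x` is replaced; `z` fails the name check and keeps its original family, and the `model` list itself is left untouched, mirroring the note that substitution does not modify the original model.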