description: Lower bound on Jensen-Shannon (JS) divergence.
Lower bound on Jensen-Shannon (JS) divergence.
```python
tfp.vi.mutual_information.lower_bound_jensen_shannon(
    logu, joint_sample_mask=None, validate_args=False, name=None
)
```
This lower bound on JS divergence is proposed in
[Goodfellow et al. (2014)][1] and [Nowozin et al. (2016)][2].
When estimating lower bounds on mutual information, the objective used to
train the critic may differ from the bound used to evaluate it
[(Poole et al., 2019)][3]. In that setup, the critic is trained with this
standard lower bound on the Jensen-Shannon divergence, as used in GANs, and
the trained critic is then evaluated with the NWJ lower bound on the KL
divergence, i.e. on mutual information (see the sketch following the example
below).
As in Eq. 7 and Eq. 8 of [Nowozin et al. (2016)][2], the bound is given by

```none
I_JS = E_p(x,y)[ log( D(x, y) ) ] + E_p(x)p(y)[ log( 1 - D(x, y) ) ],
```

where the first term is the expectation over samples from the joint
distribution (positive samples), and the second is over samples from the
product of the marginal distributions (negative samples), with

```none
D(x, y) = sigmoid(f(x, y)),
log( D(x, y) ) = -softplus(-f(x, y)),
log( 1 - D(x, y) ) = -softplus(f(x, y)).
```
`f(x, y)` is a critic function that scores all pairs of samples.
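Purely for illustration, here is a minimal sketch of how such an estimate can
be assembled from a matrix of critic scores using the softplus identities
above; `js_lower_bound_sketch` is a hypothetical helper for exposition, not
part of the TFP API (which provides `lower_bound_jensen_shannon` below):

```python
import tensorflow as tf


def js_lower_bound_sketch(scores, joint_sample_mask):
  """Illustrative JS lower bound from an [N, M] matrix of critic scores.

  `scores[i, j] = f(x[i], y[j])`; `joint_sample_mask[i, j]` is True for pairs
  drawn from the joint p(x, y) and False for pairs from p(x)p(y).
  Uses log(D) = -softplus(-f) and log(1 - D) = -softplus(f).
  """
  joint_scores = tf.boolean_mask(scores, joint_sample_mask)
  marginal_scores = tf.boolean_mask(scores, ~joint_sample_mask)
  # E_p(x,y)[log D] + E_p(x)p(y)[log(1 - D)]
  return (-tf.reduce_mean(tf.math.softplus(-joint_scores))
          - tf.reduce_mean(tf.math.softplus(marginal_scores)))
```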
Example: `X` and `Y` are samples from a joint Gaussian distribution with
correlation `0.8`, each of dimension `1`.
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
lower_bound_jensen_shannon = tfp.vi.mutual_information.lower_bound_jensen_shannon

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
    value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
    num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps

# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`
# (For the JS lower bound, the optimal critic is of the form `f(x, y) = 1 +
# log(p(x | y) / p(x))` [(Poole et al., 2019)][3].)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_identity_multiplier=conditional_stddev)
conditional_scores = conditional_dist.log_prob(x[:, tf.newaxis, :])
marginal_dist = tfd.MultivariateNormalDiag(tf.zeros(dim), tf.ones(dim))
marginal_scores = marginal_dist.log_prob(x)[:, tf.newaxis]
scores = 1 + conditional_scores - marginal_scores

# Mask for joint samples in the score tensor
# (`scores` has shape [x_batch_size, y_batch_size], i.e.
# `scores[i, j] = f(x[i], y[j]) = 1 + log p(x[i] | y[j]) - log p(x[i])`.)
joint_sample_mask = tf.eye(batch_size, dtype=bool)

# Lower bound on the Jensen-Shannon divergence
lower_bound_jensen_shannon(logu=scores, joint_sample_mask=joint_sample_mask)
```
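As mentioned above, a critic trained with this JS objective is commonly
evaluated with the NWJ (Nguyen-Wainwright-Jordan) lower bound on the KL
divergence to obtain a mutual-information estimate. The sketch below
implements that bound directly from the `scores` and `joint_sample_mask` of
the example; `nwj_lower_bound_sketch` is a hypothetical helper for exposition
and does not assume any particular TFP function:

```python
import tensorflow as tf


def nwj_lower_bound_sketch(scores, joint_sample_mask):
  """Illustrative NWJ lower bound on I(X; Y) from critic scores.

  I(X; Y) >= E_p(x,y)[f(x, y)] - exp(-1) * E_p(x)p(y)[exp(f(x, y))].
  """
  joint_scores = tf.boolean_mask(scores, joint_sample_mask)
  marginal_scores = tf.boolean_mask(scores, ~joint_sample_mask)
  return (tf.reduce_mean(joint_scores)
          - tf.exp(-1.) * tf.reduce_mean(tf.exp(marginal_scores)))


# Evaluate the (analytically optimal) critic from the example with the NWJ
# bound; for this Gaussian pair the true mutual information is
# -0.5 * log(1 - rho**2) per dimension.
nwj_lower_bound_sketch(scores, joint_sample_mask)
```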
Args | |
---|---|
`logu` | `float`-like `Tensor` of size `[batch_size_1, batch_size_2]` representing critic scores for pairs of points `(x, y)`, with `logu[i, j] = f(x[i], y[j])`. |
`joint_sample_mask` | `bool`-like `Tensor` of the same size as `logu`, marking the positive samples (i.e. samples from the joint distribution `p(x, y)`) with `True`. Default value: `None`, in which case an identity matrix is constructed as the mask. |
`validate_args` | Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. |
`name` | Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'lower_bound_jensen_shannon'). |
Returns | |
---|---|
`lower_bound` | `float`-like scalar for the lower bound on JS divergence. |
[1]: Ian J. Goodfellow et al. Generative Adversarial Nets. In Conference on Neural Information Processing Systems, 2014. https://arxiv.org/abs/1406.2661
[2]: Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. In Conference on Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.00709
[3]: Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi, George Tucker. On Variational Bounds of Mutual Information. In International Conference on Machine Learning, 2019. https://arxiv.org/abs/1905.06922