
tfp.vi.mutual_information.lower_bound_jensen_shannon

Lower bound on Jensen-Shannon (JS) divergence.
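
The call signature, as implied by the arguments documented below:

```python
tfp.vi.mutual_information.lower_bound_jensen_shannon(
    logu, joint_sample_mask=None, validate_args=False, name=None
)
```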

This lower bound on JS divergence is proposed in [Goodfellow et al. (2014)][1] and [Nowozin et al. (2016)][2]. When estimating lower bounds on mutual information, one can also use different approaches for training the critic than for estimating mutual information [(Poole et al., 2019)][3]: the critic is trained with the standard lower bound on the Jensen-Shannon divergence, as used in GANs, and the trained critic is then evaluated with the NWJ lower bound on KL divergence, i.e. on mutual information. As in Eq. 7 and Eq. 8 of [Nowozin et al. (2016)][2], the bound is given by

```none
I_JS = E_p(x,y)[log( D(x,y) )] + E_p(x)p(y)[log( 1 - D(x,y) )]
```

where the first term is the expectation over samples from the joint distribution (positive samples), and the second term is the expectation over samples from the product of the marginal distributions (negative samples), with

```none
D(x, y) = sigmoid(f(x, y)),
log( D(x, y) ) = -softplus(-f(x, y)).
```

Here `f(x, y)` is a critic function that scores all pairs of samples.
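
To make the two terms concrete, here is a minimal sketch (assuming the identity-mask convention used below, and not necessarily the library's exact implementation) that evaluates the bound from a square matrix of critic scores, using `log D = -softplus(-f)` for the positive pairs on the diagonal and `log(1 - D) = -softplus(f)` for the negative pairs off the diagonal:

```python
import tensorflow as tf

def js_lower_bound_from_scores(scores):
  """Evaluates E_p(x,y)[log D] + E_p(x)p(y)[log(1 - D)] from scores[i, j] = f(x[i], y[j])."""
  n = tf.shape(scores)[0]
  joint_mask = tf.eye(n, dtype=tf.bool)
  positives = tf.boolean_mask(scores, joint_mask)                   # f(x[i], y[i])
  negatives = tf.boolean_mask(scores, tf.logical_not(joint_mask))   # f(x[i], y[j]), i != j
  joint_term = tf.reduce_mean(-tf.math.softplus(-positives))        # E_p(x,y)[log D]
  marginal_term = tf.reduce_mean(-tf.math.softplus(negatives))      # E_p(x)p(y)[log(1 - D)]
  return joint_term + marginal_term
```

Applied to the square `scores` matrix built in the example below, this should agree with the value returned by `lower_bound_jensen_shannon` (possibly up to a constant offset, depending on the exact form implemented).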

Example:

X, Y are samples from a joint Gaussian distribution, with correlation 0.8 and both of dimension 1.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
lower_bound_jensen_shannon = tfp.vi.mutual_information.lower_bound_jensen_shannon

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
    value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
    num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps

# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`
# (For JS lower bound, the optimal critic is of the form `f(x, y) = 1 +
# log(p(x | y) / p(x))` [(Poole et al., 2019)][3].)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_identity_multiplier=conditional_stddev)
conditional_scores = conditional_dist.log_prob(x[:, tf.newaxis, :])
marginal_dist = tfd.MultivariateNormalDiag(tf.zeros(dim), tf.ones(dim))
marginal_scores = marginal_dist.log_prob(x)[:, tf.newaxis]
scores = 1 + conditional_scores - marginal_scores

# Mask for joint samples in the score tensor
# (The `scores` tensor has shape [x_batch_size, y_batch_size], with
# `scores[i, j] = f(x[i], y[j]) = 1 + log p(x[i] | y[j]) - log p(x[i])`.)
joint_sample_mask = tf.eye(batch_size, dtype=bool)

# Lower bound on the Jensen-Shannon divergence
lower_bound_jensen_shannon(logu=scores, joint_sample_mask=joint_sample_mask)
```
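
The closed-form critic above is only available because the joint distribution is known. In practice the critic is a learned function trained by maximizing this bound (train with the JS bound, evaluate mutual information with an NWJ/KL-type bound, as discussed above). A minimal sketch of such a training step, assuming a hypothetical separable critic `f(x, y) = <g(x), h(y)>` built from small Keras networks; the architecture and sizes are illustrative, not from the original docs:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical separable critic f(x, y) = <g(x), h(y)>; sizes are illustrative.
g = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),
                         tf.keras.layers.Dense(32)])
h = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),
                         tf.keras.layers.Dense(32)])
optimizer = tf.keras.optimizers.Adam(1e-3)

def train_step(x, y):
  """One critic update on a minibatch of jointly drawn pairs (x[i], y[i])."""
  with tf.GradientTape() as tape:
    # scores[i, j] = f(x[i], y[j]); the diagonal holds the positive pairs.
    scores = tf.matmul(g(x), h(y), transpose_b=True)
    loss = -tfp.vi.mutual_information.lower_bound_jensen_shannon(logu=scores)
  variables = g.trainable_variables + h.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
  return -loss  # current value of the JS lower bound
```

Calling `train_step` on minibatches of jointly drawn pairs maximizes the JS lower bound; the trained critic's scores can then be plugged into a KL-type bound to estimate mutual information.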

Args:

- `logu`: `float`-like `Tensor` of size `[batch_size_1, batch_size_2]` representing critic scores (scores) for pairs of points `(x, y)`, with `logu[i, j] = f(x[i], y[j])`.
- `joint_sample_mask`: `bool`-like `Tensor` of the same size as `logu`, masking the positive samples (i.e. samples from the joint distribution `p(x, y)`) with `True`. Default value: `None`. By default, an identity matrix is constructed as the mask.
- `validate_args`: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed.
- `name`: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., `'lower_bound_jensen_shannon'`).

Returns:

- `lower_bound`: `float`-like scalar for the lower bound on JS divergence.
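
A small illustration of the default masking behavior, assuming square critic scores so that the implicit identity mask marks the diagonal as positive pairs; per the argument description above, the explicit `tf.eye` mask should give the same value:

```python
import tensorflow as tf
import tensorflow_probability as tfp

scores = tf.random.normal([128, 128], seed=17)
mask = tf.eye(128, dtype=bool)

# Equivalent calls: the second relies on the default identity mask.
explicit = tfp.vi.mutual_information.lower_bound_jensen_shannon(
    logu=scores, joint_sample_mask=mask)
implicit = tfp.vi.mutual_information.lower_bound_jensen_shannon(logu=scores)
```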

References:

[1]: Ian J. Goodfellow, et al. Generative Adversarial Nets. In Conference on Neural Information Processing Systems, 2014. https://arxiv.org/abs/1406.2661

[2]: Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. In Conference on Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.00709

[3]: Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi, George Tucker. On Variational Bounds of Mutual Information. In International Conference on Machine Learning, 2019. https://arxiv.org/abs/1905.06922