tfp.experimental.as_composite

Returns a CompositeTensor equivalent to the given object.

Note that any Variable, tfp.util.DeferredTensor, or tfp.util.TransformedVariable references that the given object closes over are converted to tensors at the time this function is called, so later updates to those variables are not reflected in the returned object. The type of the returned object is a subclass of both CompositeTensor and type(obj). For this reason, use as_composite() with care, especially for tf.Module objects.
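The capture semantics described above can be illustrated with a toy model in plain Python. This is a hypothetical analogue, not TFP's implementation: Variable, Deferred, Normal, CompositeMarker, and snapshot are all illustrative stand-ins for tf.Variable, tfp.util.DeferredTensor, tfp.distributions.Normal, CompositeTensor, and as_composite. It shows two points: deferred parameters are read once, when the conversion runs, and the result is an instance of a new class deriving from both the marker class and type(obj).

```python
class Variable:
    """Toy stand-in for tf.Variable: a mutable value container."""
    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


class Deferred:
    """Toy stand-in for tfp.util.DeferredTensor: applies fn on every read."""
    def __init__(self, var, fn):
        self.var, self.fn = var, fn

    def read(self):
        return self.fn(self.var.value)


class Normal:
    """Toy distribution whose loc may be a Deferred parameter."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def mean(self):
        loc = self.loc
        return loc.read() if isinstance(loc, Deferred) else loc


class CompositeMarker:
    """Toy stand-in for CompositeTensor."""


def snapshot(obj):
    """Toy analogue of as_composite (hypothetical, not the TFP code).

    Reads every Deferred parameter now and rebuilds the object as an
    instance of a new class that subclasses both CompositeMarker and type(obj).
    """
    cls = type("Composite" + type(obj).__name__, (CompositeMarker, type(obj)), {})
    params = {k: (v.read() if isinstance(v, Deferred) else v)
              for k, v in vars(obj).items()}
    return cls(**params)


v = Variable(1.0)
d = Normal(Deferred(v, lambda x: x + 1), 10.0)
ct = snapshot(d)   # loc is read here: 1.0 + 1 = 2.0
v.assign(2.0)
print(d.mean())    # 3.0: the deferred loc is re-read on each call
print(ct.mean())   # 2.0: the value was fixed when snapshot() ran
```

Because the snapshot holds plain values rather than references, it cannot observe later assignments; this mirrors, in simplified form, why the returned object does not track variable updates.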

For example, even when the composite tensor is created as part of a tf.Module, it "fixes" the values of the DeferredTensor and tf.Variable objects it uses at construction time:

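import tensorflow as tf
import tensorflow_probability as tfp

class M(tf.Module):
  def __init__(self):
    self._v = tf.Variable(1.)
    self._d = tfp.distributions.Normal(
      tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
    # The composite is built here, in __init__, so the current value of
    # the DeferredTensor (v + 1 with v = 1.) is converted to a tensor now.
    self._dct = tfp.experimental.as_composite(self._d)

  @tf.function
  def mean(self):
    return self._dct.mean()

m = M()
m.mean()

m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
m.mean()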

If, however, the creation of the composite is deferred to a method call, then the Variable and DeferredTensor will be properly captured and respected by the Module and its SavedModel (if it is serialized).

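class M(tf.Module):
  def __init__(self):
    self._v = tf.Variable(1.)
    self._d = tfp.distributions.Normal(
      tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

  @tf.function
  def d(self):
    # The composite is created inside the traced method, so the variable
    # is read when the function executes rather than frozen in __init__.
    return tfp.experimental.as_composite(self._d)

m = M()
m.d().mean()

m._v.assign(2.)
m.d().mean()  # Reflects the updated variable.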

Note: This method is best-effort and relies on a heuristic to decide which parameters are tensors and which are not. Conversion may fail or behave incorrectly, especially for meta-distributions like TransformedDistribution or Independent; in such cases we try to raise NotImplementedError. If you would benefit from better coverage, please file an issue on GitHub or send an email to tfprobability@tensorflow.org.
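Because conversion may raise NotImplementedError, callers that can work without a CompositeTensor may want a fallback. The sketch below is a hypothetical helper, not part of TFP: maybe_as_composite and its convert parameter are illustrative names. It returns the converted object together with a flag, so the caller can tell whether conversion actually happened. The converter is injectable, and TFP is imported lazily only when no converter is passed.

```python
from typing import Any, Callable, Optional, Tuple


def maybe_as_composite(
    obj: Any, convert: Optional[Callable[[Any], Any]] = None
) -> Tuple[Any, bool]:
    """Hypothetical helper: try as_composite, fall back to obj on failure.

    Returns (result, converted), where converted is False if the
    best-effort conversion raised NotImplementedError.
    """
    if convert is None:
        # Lazy import so the helper can be exercised without TFP installed.
        import tensorflow_probability as tfp
        convert = tfp.experimental.as_composite
    try:
        return convert(obj), True
    except NotImplementedError:
        # Unsupported case (e.g. some meta-distributions): keep the original,
        # non-composite object and let the caller decide what to do.
        return obj, False
```

Returning a flag instead of silently passing the original object through makes the fallback explicit: a caller that truly needs a CompositeTensor (for example, to return a distribution from a tf.function) can check the flag and raise its own error.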

Args:
obj: A tfp.distributions.Distribution.

Returns:
obj: A tfp.distributions.Distribution that extends CompositeTensor.