tfp.substrates.jax.util.SeedStream

Local PRNG for amplifying seed entropy into seeds for base operations.

Writing sampling code which correctly sets the pseudo-random number generator (PRNG) seed is surprisingly difficult. This class serves as a helper for the TensorFlow Probability coding pattern designed to avoid common mistakes.

Motivating Example

A common first-cut implementation of a beta sampler draws two independent gamma variates, X ~ Gamma(alpha) and Y ~ Gamma(beta), and returns X / (X + Y). This code snippet tries to do that, but contains a surprisingly common error:

def broken_beta(shape, alpha, beta, seed):
  x = tf.random.gamma(shape, alpha, seed=seed)
  y = tf.random.gamma(shape, beta, seed=seed)
  return x / (x + y)

The mistake is that the two gamma draws are seeded with the same seed. This causes them to always produce the same results, which, in turn, leads this code snippet to always return 0.5. Because it can happen across abstraction boundaries, this kind of error is surprisingly easy to make when handling immutable seeds.

Goals

TensorFlow Probability adopts a code style designed to eliminate the above class of error, without exacerbating others. The goals of this code style are:

Non-goals

Code pattern

def random_beta(shape, alpha, beta, seed):        # (a)
  seed = SeedStream(seed, salt="random_beta")     # (b)
  x = tf.random.gamma(shape, alpha, seed=seed())  # (c)
  y = tf.random.gamma(shape, beta, seed=seed())   # (c)
  return x / (x + y)

The elements of this pattern are:

(a) The sampler accepts an explicit seed argument, so callers control reproducibility.

(b) The incoming seed is wrapped in a SeedStream salted with a string unique to this function (see "Why salt?" below).

(c) Every downstream sampling operation obtains its own seed by calling the SeedStream instance, so no two operations share a seed.
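Because this page documents the JAX substrate, the same pattern can also be written against JAX's functional PRNG. The following sketch is illustrative rather than canonical (the function name random_beta_jax is made up here), and it assumes, per the __call__ note below, that calling the stream yields JAX PRNGKeys accepted by jax.random samplers:

import jax
from tensorflow_probability.substrates import jax as tfp

def random_beta_jax(shape, alpha, beta, seed):
  # (b) Wrap the incoming seed (a JAX PRNGKey) in a salted SeedStream.
  stream = tfp.util.SeedStream(seed, salt="random_beta_jax")
  # (c) Each base sampling operation draws its own key from the stream.
  x = jax.random.gamma(stream(), alpha, shape)
  y = jax.random.gamma(stream(), beta, shape)
  return x / (x + y)

# (a) The caller supplies the seed explicitly.
samples = random_beta_jax((3,), 2.0, 5.0, seed=jax.random.PRNGKey(42))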

Why salt?

Salting the SeedStream instances (with unique salts) is defensive programming against a client accidentally committing a mistake similar to our motivating example. Consider the following situation that might arise without salting:

def tfp_foo(seed):
  seed = SeedStream(seed, salt="")
  foo_stuff = tf.random.normal([], seed=seed())
  ...

def tfp_bar(seed):
  seed = SeedStream(seed, salt="")
  bar_stuff = tf.random.normal([], seed=seed())
  ...

def client_baz(seed):
  foo = tfp_foo(seed=seed)
  bar = tfp_bar(seed=seed)
  ...

The client should have used different seeds as inputs to foo and bar. However, because they didn't, and because foo and bar both sample a Gaussian internally as their first action, the internal foo_stuff and bar_stuff will be the same, and the returned foo and bar will not be independent, leading to subtly incorrect answers from the client's simulation. This kind of bug is particularly insidious for the client, because it depends on a Distributions implementation detail, namely the order in which foo and bar invoke the samplers they depend on. In particular, a Bayesflow team member can introduce such a bug in previously (accidentally) correct client code by performing an internal refactoring that causes this operation order alignment.

A salting discipline eliminates this problem by making sure that the seeds seen by foo's callees will differ from those seen by bar's callees, even if foo and bar are invoked with the same input seed.
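As a small illustrative check of this property using the JAX substrate (the salt strings here are hypothetical), two streams built from the same input key but different salts produce different per-operation seeds:

import jax
from tensorflow_probability.substrates import jax as tfp

key = jax.random.PRNGKey(0)
# Same input seed, different salts: the streams emit different
# per-operation seeds, so downstream draws remain independent.
foo_stream = tfp.util.SeedStream(key, salt="tfp_foo")
bar_stream = tfp.util.SeedStream(key, salt="tfp_bar")
print(foo_stream())
print(bar_stream())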

Args

seed: Any Python object convertible to string, with the exception of Tensor objects, which are reserved for stateless sampling semantics. The seed supplies the initial entropy. If None, operations seeded with seeds drawn from this SeedStream will follow TensorFlow semantics for not being seeded.

salt: Any Python object convertible to string, supplying auxiliary entropy. Must be unique across the Distributions and TensorFlow Probability code base. See the class docstring for rationale.

Attributes

original_seed

salt

Methods

__call__

The __call__ implementation for JAX PRNGKeys. Each call returns a fresh seed (a JAX PRNGKey) to pass to a single base sampling operation.
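A rough usage sketch, assuming (as stated above) that each call yields a JAX PRNGKey suitable for jax.random samplers; the salt string here is arbitrary:

import jax
from tensorflow_probability.substrates import jax as tfp

stream = tfp.util.SeedStream(jax.random.PRNGKey(17), salt="example_salt")
key1 = stream()  # a PRNGKey for the first base operation
key2 = stream()  # a distinct PRNGKey for the next base operation
print(jax.random.normal(key1), jax.random.normal(key2))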