flax.linen.SpectralNorm#

class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Spectral normalization.

See:

- https://arxiv.org/abs/1802.05957
- https://arxiv.org/abs/1805.08318
- https://arxiv.org/abs/1809.11096

Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectrally normalized before computing its __call__ output.
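
For intuition, here is a minimal sketch (not the module's internal implementation) of how power iteration approximates the largest singular value sigma of a 2-D weight matrix and rescales the matrix by it; the helper names below are illustrative only:

import jax
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
  # Normalize a vector to unit length; eps guards against division by zero.
  return x / (jnp.linalg.norm(x) + eps)

def spectral_normalize(w, u, n_steps=1, eps=1e-12):
  # Run a few power-iteration steps to approximate the largest singular
  # value (the spectral norm) of `w`, then rescale `w` by it.
  for _ in range(n_steps):
    v = l2_normalize(w.T @ u, eps)  # estimate of the right singular vector
    u = l2_normalize(w @ v, eps)    # estimate of the left singular vector
  sigma = u @ w @ v                 # approximate largest singular value
  return w / sigma, u

w = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
u = l2_normalize(jax.random.normal(jax.random.PRNGKey(1), (4,)))
w_sn, u = spectral_normalize(w, u, n_steps=5)
print(jnp.linalg.norm(w_sn, ord=2))  # close to 1.0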

Usage Note: The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True and mutable=['batch_stats'] so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model. For example:

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    x = nn.Dense(3)(x)
    # only spectrally normalize the params of the second Dense layer
    x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
    x = nn.Dense(5)(x)
    return x

# init
x = jnp.ones((1, 2))
y = jnp.ones((1, 5))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x, train=False)

# train
def train_step(variables, x, y):
  def loss_fn(params):
    logits, updates = model.apply(
        {'params': params, 'batch_stats': variables['batch_stats']},
        x,
        train=True,
        mutable=['batch_stats'],
    )
    loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
    return loss, updates

  (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      variables['params']
  )
  return {
      'params': jax.tree_util.tree_map(
          lambda p, g: p - 0.1 * g, variables['params'], grads
      ),
      'batch_stats': updates['batch_stats'],
  }, loss

for _ in range(10):
  variables, loss = train_step(variables, x, y)

# inference / eval
out = model.apply(variables, x, train=False)

layer_instance#

Module instance that is wrapped with SpectralNorm.

n_steps#

How many steps of power iteration to perform to approximate the singular value of the weight params.
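
A larger step count trades extra compute for a closer approximation of the largest singular value. For instance, reusing the wrapped Dense layer from the example above (the step count here is arbitrary):

x = nn.SpectralNorm(nn.Dense(4), n_steps=3)(x, update_stats=train)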

epsilon#

A small float added to l2-normalization to avoid dividing by zero.

dtype#

The dtype of the result (default: infer from input and params).

param_dtype#

The dtype passed to parameter initializers (default: float32).

error_on_non_matrix#

Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.
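
For illustration, a Conv kernel has rank 4, so with the default error_on_non_matrix=False it is flattened to a matrix before normalization, while error_on_non_matrix=True would raise instead. A brief sketch using the same imports as the example above (shapes chosen arbitrarily):

conv = nn.SpectralNorm(nn.Conv(features=8, kernel_size=(3, 3)))
x = jnp.ones((1, 16, 16, 3))
variables = conv.init(jax.random.PRNGKey(0), x, update_stats=False)
# The rank-4 Conv kernel is flattened to a matrix before normalization;
# with error_on_non_matrix=True this call would raise an error instead.
y = conv.apply(variables, x, update_stats=False)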

collection_name#

Name of the collection to store intermediate values used when performing spectral normalization.
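
A custom collection_name only changes where the intermediate u and sigma values are stored, and therefore which collection must be marked mutable during training. A brief sketch using the same imports as the example above ('sn_stats' is an arbitrary choice):

layer = nn.SpectralNorm(nn.Dense(4), collection_name='sn_stats')
x = jnp.ones((1, 3))
variables = layer.init(jax.random.PRNGKey(0), x, update_stats=False)
# u and sigma now live under 'sn_stats' instead of 'batch_stats'.
y, updates = layer.apply(variables, x, update_stats=True, mutable=['sn_stats'])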

Parameters: