flax.linen.SpectralNorm#
- class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Spectral normalization.
See:

- https://arxiv.org/abs/1802.05957
- https://arxiv.org/abs/1805.08318
- https://arxiv.org/abs/1809.11096
Spectral normalization normalizes the weight params so that the spectral norm of the weight matrix is equal to 1. This is implemented as a layer wrapper, where each wrapped layer has its params spectrally normalized before computing its __call__ output.

Usage Note: The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True and mutable=['batch_stats'] so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model. For example:

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x, train):
        x = nn.Dense(3)(x)
        # only spectral normalize the params of the second Dense layer
        x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
        x = nn.Dense(5)(x)
        return x


# init
x = jnp.ones((1, 2))
y = jnp.ones((1, 5))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x, train=False)


# train
def train_step(variables, x, y):
    def loss_fn(params):
        logits, updates = model.apply(
            {'params': params, 'batch_stats': variables['batch_stats']},
            x,
            train=True,
            mutable=['batch_stats'],
        )
        loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
        return loss, updates

    (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        variables['params']
    )
    return {
        'params': jax.tree_util.tree_map(
            lambda p, g: p - 0.1 * g, variables['params'], grads
        ),
        'batch_stats': updates['batch_stats'],
    }, loss


for _ in range(10):
    variables, loss = train_step(variables, x, y)

# inference / eval
out = model.apply(variables, x, train=False)
```
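To build intuition for what the stored u vector and sigma value are, the normalization itself can be sketched in a few lines of JAX. This is a minimal, self-contained sketch of power iteration as described in the referenced papers, not Flax's internal implementation; the helper name `spectral_normalize` and the toy shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def spectral_normalize(w, u, n_steps=1, eps=1e-12):
    """Sketch of one spectral-normalization update for a 2D weight matrix.

    Returns the normalized weights, the updated `u` vector, and the
    estimated largest singular value `sigma`. Illustrative only; it is
    not Flax's exact implementation.
    """
    for _ in range(n_steps):
        # Power iteration: alternately project through W^T and W,
        # l2-normalizing with a small epsilon to avoid division by zero.
        v = w.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
        u = w @ v
        u = u / (jnp.linalg.norm(u) + eps)
    sigma = u @ w @ v  # approximate largest singular value
    return w / sigma, u, sigma


# Toy usage: with enough iterations, sigma approaches the true largest
# singular value and the normalized matrix has spectral norm close to 1.
w = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
u = jax.random.normal(jax.random.PRNGKey(1), (4,))
w_sn, u, sigma = spectral_normalize(w, u, n_steps=50)
print(sigma, jnp.linalg.svd(w, compute_uv=False)[0])  # should be close
```

Because the u vector persists in the 'batch_stats' collection between training steps, even a single power-iteration step per call (the default n_steps=1) lets the estimate of sigma improve over the course of training.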
- layer_instance#
Module instance that is wrapped with SpectralNorm.
- n_steps#
How many steps of power iteration to perform to approximate the singular value of the weight params.
- epsilon#
A small float added to l2-normalization to avoid dividing by zero.
- dtype#
the dtype of the result (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- error_on_non_matrix#
Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.
- collection_name#
Name of the collection to store intermediate values used when performing spectral normalization.
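As a quick way to see what ends up in the collection named by collection_name, you can inspect the initialized variables from the Foo example above. This sketch only assumes the default collection name 'batch_stats' and uses jax.tree_util.tree_map to print shapes rather than asserting exact per-layer key names.

```python
# Reusing `model` and `x` from the example above: inspect the auxiliary
# collection that SpectralNorm creates (default collection_name='batch_stats').
variables = model.init(jax.random.PRNGKey(0), x, train=False)

# The 'params' collection holds the usual Dense kernels and biases; the
# 'batch_stats' collection holds the `u` vectors and `sigma` values.
print(list(variables.keys()))  # e.g. ['params', 'batch_stats']

# Print only the shapes, without assuming the exact per-layer key names.
print(jax.tree_util.tree_map(jnp.shape, variables['batch_stats']))
```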