flax.linen.WeightNorm#

class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf).

Weight normalization normalizes the weight parameters so that their l2-norm equals 1, optionally rescaling the result by a learned per-feature scale. It is implemented as a layer wrapper: each wrapped layer has its parameters l2-normalized before its __call__ output is computed.
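
Conceptually, each matched parameter is divided by its l2-norm over the non-feature axes and then multiplied by the learned scale. A minimal sketch of that transformation, assuming feature_axes=-1 and use_scale=True (the helper name _l2_normalize here is illustrative, not part of the public API):

import jax.numpy as jnp

def _l2_normalize(w, scale, epsilon=1e-12):
  # Reduce over every axis except the trailing (feature) axis.
  reduction_axes = tuple(range(w.ndim - 1))
  norm = jnp.sqrt(jnp.sum(jnp.square(w), axis=reduction_axes, keepdims=True) + epsilon)
  # One norm (and one scale entry) per feature.
  return scale * w / norm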

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Baz(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(2)(x)

class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = Baz()(x)
    x = nn.Dense(3)(x)
    x = Baz()(x)
    x = nn.Dense(3)(x)
    return x

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    # l2-normalize all params of the second Dense layer
    x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
    x = nn.Dense(5)(x)
    # l2-normalize all kernels in the Bar submodule and all params in the
    # Baz submodule
    x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
    return x

# init
x = jnp.ones((1, 2))
model = Foo()
variables = model.init(jax.random.key(0), x)

variables
# {
#   params: {
#     ...
#     WeightNorm_0: {
#         Dense_1/bias/scale: Array([1., 1., 1., 1.], dtype=float32),
#         Dense_1/kernel/scale: Array([1., 1., 1., 1.], dtype=float32),
#     },
#     ...
#     WeightNorm_1: {
#         Bar_0/Baz_0/Dense_0/bias/scale: Array([1., 1.], dtype=float32),
#         Bar_0/Baz_0/Dense_0/kernel/scale: Array([1., 1.], dtype=float32),
#         Bar_0/Baz_1/Dense_0/bias/scale: Array([1., 1.], dtype=float32),
#         Bar_0/Baz_1/Dense_0/kernel/scale: Array([1., 1.], dtype=float32),
#         Bar_0/Dense_0/kernel/scale: Array([1., 1., 1.], dtype=float32),
#         Bar_0/Dense_1/kernel/scale: Array([1., 1., 1.], dtype=float32),
#     },
#     ...
#   }
# }
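
The wrapped model is applied like any other module; a short follow-up forward pass:

# Forward pass; Foo's final WeightNorm(Bar()) output has shape (1, 3),
# matching Bar's last Dense(3).
y = model.apply(variables, x)
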
layer_instance#

Module instance that is wrapped with WeightNorm.

epsilon#

A small float added to l2-normalization to avoid dividing by zero.

dtype#

The dtype of the result (default: infer from input and params).

param_dtype#

The dtype passed to parameter initializers (default: float32).

use_scale#

If True, creates a learnable variable scale that is multiplied with the layer_instance variables after l2-normalization.
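
For instance, with use_scale=False no scale parameters are created, so the wrapper adds no parameters of its own; a small hypothetical sketch:

# Hypothetical: only the wrapped Dense params are created; the params
# collection contains no .../scale entries.
wn = nn.WeightNorm(nn.Dense(4), use_scale=False)
vars_no_scale = wn.init(jax.random.key(0), jnp.ones((1, 3)))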

scale_init#

Initialization function for the scale parameter.

feature_axes#

The feature axes dimension(s). The l2-norm is calculated by reducing the layer_instance variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated and a separate scale (if use_scale=True) is learned for each specified feature. By default, the trailing dimension is treated as the feature axis.
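
For example, for a kernel of shape (3, 4) under the default feature_axes=-1, the norm is reduced over axis 0 and one scale value is learned per output feature; a minimal sketch of the resulting shapes:

import jax.numpy as jnp

kernel = jnp.ones((3, 4))
# feature_axes=-1: reduce over the non-feature axis 0 -> one norm per feature.
norms = jnp.sqrt(jnp.sum(jnp.square(kernel), axis=0))
assert norms.shape == (4,)  # the matching learnable scale is also shape (4,)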

variable_filter#

An optional iterable of string items. The WeightNorm layer selectively applies l2-normalization to the layer_instance variables whose key path (delimited by '/') contains at least one item of variable_filter. For example, variable_filter={'kernel'} applies l2-normalization only to variables whose key path contains 'kernel'. By default, variable_filter={'kernel'}.

Parameters: