flax.linen.WeightNorm
- class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)
L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf).
Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper: each wrapped layer has its params l2-normalized before its `__call__` output is computed.

Example:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Baz(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(2)(x)


class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = Baz()(x)
    x = nn.Dense(3)(x)
    x = Baz()(x)
    x = nn.Dense(3)(x)
    return x


class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    # l2-normalize all params of the second Dense layer
    x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
    x = nn.Dense(5)(x)
    # l2-normalize all kernels in the Bar submodule and all params in the
    # Baz submodule
    x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
    return x


# init
x = jnp.ones((1, 2))
model = Foo()
variables = model.init(jax.random.key(0), x)
variables
# {
#   params: {
#     ...
#     WeightNorm_0: {
#       Dense_1/bias/scale: Array([1., 1., 1., 1.], dtype=float32),
#       Dense_1/kernel/scale: Array([1., 1., 1., 1.], dtype=float32),
#     },
#     ...
#     WeightNorm_1: {
#       Bar_0/Baz_0/Dense_0/bias/scale: Array([1., 1.], dtype=float32),
#       Bar_0/Baz_0/Dense_0/kernel/scale: Array([1., 1.], dtype=float32),
#       Bar_0/Baz_1/Dense_0/bias/scale: Array([1., 1.], dtype=float32),
#       Bar_0/Baz_1/Dense_0/kernel/scale: Array([1., 1.], dtype=float32),
#       Bar_0/Dense_0/kernel/scale: Array([1., 1., 1.], dtype=float32),
#       Bar_0/Dense_1/kernel/scale: Array([1., 1., 1.], dtype=float32),
#     },
#     ...
#   }
# }
```
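Conceptually, the transformation applied to each matched variable looks like the following minimal sketch (not the library internals; the exact placement of `epsilon` is an assumption here), with the default `feature_axes=-1` treating the trailing dimension as the feature axis:

```python
import jax.numpy as jnp

def weight_norm(value, scale, epsilon=1e-12):
  # Reduce over all non-feature axes; with feature_axes=-1, the trailing
  # axis is the feature axis.
  reduce_axes = tuple(range(value.ndim - 1))
  norm = jnp.sqrt(
      jnp.sum(jnp.square(value), axis=reduce_axes, keepdims=True))
  # epsilon guards against division by zero (assumed placement).
  value = value / (norm + epsilon)
  # With use_scale=True, a learned per-feature scale multiplies the
  # normalized variable.
  return scale * value
```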
- layer_instance
Module instance that is wrapped with WeightNorm.
- epsilon
A small float added to l2-normalization to avoid dividing by zero.
- dtype
The dtype of the result (default: infer from input and params).
- param_dtype
The dtype passed to parameter initializers (default: float32).
- use_scale
If True, creates a learnable variable `scale` that is multiplied with the `layer_instance` variables after l2-normalization.
- scale_init
Initialization function for the `scale` variable.
- feature_axes
The feature axes dimension(s). The l2-norm is calculated by reducing the `layer_instance` variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated, and a separate scale (if `use_scale=True`) is learned, for each specified feature. By default, the trailing dimension is treated as the feature axis. See the sketch after this list for how the default affects the shape of `scale`.
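For instance, wrapping `nn.Dense(4)` applied to 3-dimensional inputs gives a kernel of shape `(3, 4)`; with the default `feature_axes=-1`, the l2-norm is reduced over axis 0, so one scale entry is learned per output feature. A minimal sketch (the key naming follows the example above):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.WeightNorm(nn.Dense(4))(x)

variables = Model().init(jax.random.key(0), jnp.ones((1, 3)))
# With the default variable_filter={'kernel'}, only the kernel (shape (3, 4))
# receives a scale; its l2-norm is reduced over axis 0, one value per feature.
print(jax.tree_util.tree_map(jnp.shape, variables['params']['WeightNorm_0']))
# {'Dense_0/kernel/scale': (4,)}
```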
- variable_filter
An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the `layer_instance` variables whose key path (delimited by '/') has a match with `variable_filter`. For example, `variable_filter={'kernel'}` will only apply l2-normalization to variables whose key path contains 'kernel'. By default, `variable_filter={'kernel'}`. See the sketch below for the effect of the default filter.
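To illustrate (a sketch, following the same pattern as the example above): with the default filter only key paths containing 'kernel' are normalized, whereas `variable_filter=None` disables filtering, so every variable of the wrapped layer, including the bias, receives a scale:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class KernelOnly(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Default variable_filter={'kernel'}: only 'kernel' key paths match.
    return nn.WeightNorm(nn.Dense(4))(x)

class AllParams(nn.Module):
  @nn.compact
  def __call__(self, x):
    # variable_filter=None: l2-normalize all variables of the wrapped layer.
    return nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)

x = jnp.ones((1, 3))
print(list(KernelOnly().init(jax.random.key(0), x)['params']['WeightNorm_0']))
# ['Dense_0/kernel/scale']
print(list(AllParams().init(jax.random.key(0), x)['params']['WeightNorm_0']))
# ['Dense_0/bias/scale', 'Dense_0/kernel/scale']
```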