flax.linen.GroupNorm
- class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)
Group normalization (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels rather than across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user must specify exactly one of num_groups (the total number of channel groups) or group_size (the number of channels per group).
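A minimal usage sketch (the shapes and seeds here are illustrative, not part of the API): 32 channels are split into 4 groups of 8, and each example is normalized over its spatial dimensions and the channels within each group.

```python
import jax
import flax.linen as nn

# (batch, height, width, channels); 32 channels split into 4 groups of 8.
x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 32))

layer = nn.GroupNorm(num_groups=4)
variables = layer.init(jax.random.PRNGKey(1), x)
y = layer.apply(variables, x)

assert y.shape == x.shape  # normalization preserves the input shape
```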
- num_groups
The total number of channel groups. The default value of 32 was proposed by the original group normalization paper.
- group_size
The number of channels in a group. When this is set, num_groups must be passed as None (see the examples after this list).
- epsilon
A small float added to the variance to avoid dividing by zero.
- dtype
The dtype of the result (default: infer from input and params).
- param_dtype
The dtype passed to parameter initializers (default: float32).
- use_bias
If True, bias (beta) is added.
- use_scale
If True, multiply by scale (gamma). When the next layer is linear (or piecewise linear, e.g. nn.relu), this can be disabled, since the scaling will be done by the next layer (see the examples after this list).
- bias_init
Initializer for the bias, by default zeros.
- scale_init
Initializer for the scale, by default ones.
- axis_name
The axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. See the pmap sketch after this list.
- axis_index_groups
Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
- use_fast_variance
If True, use a faster, but less numerically stable, calculation for the variance.
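The two grouping parameters describe the same partition from different directions. A small sketch of the equivalence (shapes and seeds are illustrative): over 32 channels, num_groups=4 and group_size=8 produce identical results; since num_groups defaults to 32, it must be explicitly set to None when group_size is used.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 32))

# 4 groups of 8 channels, specified either way.
by_groups = nn.GroupNorm(num_groups=4)
by_size = nn.GroupNorm(num_groups=None, group_size=8)

v1 = by_groups.init(jax.random.PRNGKey(1), x)
v2 = by_size.init(jax.random.PRNGKey(1), x)

assert jnp.allclose(by_groups.apply(v1, x), by_size.apply(v2, x))
```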
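A sketch of disabling the scale when the next layer is linear, as described under use_scale (the block structure and feature sizes are illustrative):

```python
import jax
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        # The following Dense layer can absorb any per-channel scaling into
        # its kernel, so gamma would be redundant here.
        x = nn.GroupNorm(num_groups=4, use_scale=False)(x)
        return nn.Dense(features=64)(x)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 32))
variables = Block().init(jax.random.PRNGKey(1), x)
y = Block().apply(variables, x)  # shape (2, 8, 8, 64)
```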
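A hedged sketch of axis_name under jax.pmap, for the case where the array being normalized is split across devices (the device count and shapes are illustrative; on a single-device host this runs with one shard). Both init and apply run inside pmap so that the axis name is bound when the layer's collective is traced:

```python
import jax
import flax.linen as nn

n_dev = jax.local_device_count()
layer = nn.GroupNorm(num_groups=4, axis_name='devices')

# Each device holds one shard of the same examples (e.g. a slice of the
# spatial extent); axis_name combines the per-example group statistics
# across the 'devices' pmap axis.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 2, 4, 4, 16))

variables = jax.pmap(layer.init, axis_name='devices')(
    jax.random.split(jax.random.PRNGKey(1), n_dev), x
)
y = jax.pmap(layer.apply, axis_name='devices')(variables, x)
```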