flax.linen.BatchNorm
- class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)
BatchNorm Module.
Usage Note: If we define a model with BatchNorm, for example:
BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5, dtype=jnp.float32)
The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain all the running statistics for all the BatchNorm layers in a model:
vars_initialized = BN.init(key, x) # {'params': ..., 'batch_stats': ...}
We then update the batch_stats during training by specifying that the batch_stats collection is mutable in the apply method for our module:
vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']
During eval we would define BN with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:
vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BN.apply(vars_in, x)
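Putting these steps together, here is a minimal end-to-end sketch; the shapes, PRNG key, and variable names are illustrative, not part of the API:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 8))                      # dummy batch: 4 examples, 8 features
key = jax.random.PRNGKey(0)

# Training mode: normalize with statistics computed from the batch.
BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5)
variables = BN.init(key, x)               # {'params': ..., 'batch_stats': ...}
params = variables['params']
batch_stats = variables['batch_stats']

# One training step: 'batch_stats' must be marked mutable so the updated
# running statistics are returned alongside the output.
y, mutated_vars = BN.apply(
    {'params': params, 'batch_stats': batch_stats}, x,
    mutable=['batch_stats'])
batch_stats = mutated_vars['batch_stats']

# Eval mode: reuse the running statistics; nothing needs to be mutable.
BN_eval = nn.BatchNorm(use_running_average=True, momentum=0.9, epsilon=1e-5)
y = BN_eval.apply({'params': params, 'batch_stats': batch_stats}, x)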
- use_running_average
if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
- axis
the feature or non-batch axis of the input.
- momentum
decay rate for the exponential moving average of the batch statistics; at each step the running statistics are updated as running = momentum * running + (1 - momentum) * batch_statistic.
- epsilon
a small float added to variance to avoid dividing by zero.
- dtype
the dtype of the result (default: infer from input and params).
- param_dtype
the dtype passed to parameter initializers (default: float32).
- use_bias
if True, bias (beta) is added.
- use_scale
if True, multiply by scale (gamma). When the next layer is linear (this also holds for e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- bias_init
initializer for the bias (default: zeros).
- scale_init
initializer for the scale (default: ones).
- axis_name
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). Note that this is only used for pmap and shard_map; for SPMD jit you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. A sketch of axis_name under pmap follows this list.
- axis_index_groups
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
- use_fast_variance
if True, use a faster, but less numerically stable, calculation for the variance.
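As an illustration of axis_name, here is a hedged sketch of synchronizing batch statistics across devices with jax.pmap; the axis name 'batch', the shapes, and the replication helper are assumptions for this example, not requirements of the API:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Ask BatchNorm to average its batch statistics across the pmapped
# 'batch' axis. axis_index_groups=[[0, 1], [2, 3]] would instead
# synchronize within sub-groups of devices.
BN = nn.BatchNorm(use_running_average=False, momentum=0.9,
                  epsilon=1e-5, axis_name='batch')

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 4, 8))               # one shard of 4 examples per device

# Initialization runs outside pmap on a single shard; the cross-device
# collective only takes effect when the module is applied under pmap.
variables = BN.init(jax.random.PRNGKey(0), x[0])
replicated = jax.device_put_replicated(variables, jax.local_devices())

def step(vars_in, xs):
    # Inside pmap, per-device mean/variance are combined across the
    # 'batch' axis before normalizing.
    return BN.apply(vars_in, xs, mutable=['batch_stats'])

y, mutated_vars = jax.pmap(step, axis_name='batch')(replicated, x)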