flax.linen.Dropout
- class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)
Create a dropout layer.
Note: When calling Module.apply() with dropout enabled, pass an RNG key named 'dropout' through the rngs argument. Dropout is not needed for variable initialization, so init can run with dropout disabled and without a 'dropout' key. Example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    x = nn.Dense(4)(x)
    x = nn.Dropout(0.5, deterministic=not train)(x)
    return x

model = MLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.key(0), x, train=False)  # don't use dropout
model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)})  # use dropout
```
- rate
The dropout probability, i.e. the probability that each element is zeroed (_not_ the keep rate!).
- broadcast_dims
The dimensions along which the same dropout mask is shared: the mask has size 1 on these axes and is broadcast across them.
- deterministic
If False, the inputs are masked and the kept elements are scaled by 1 / (1 - rate). If True, no mask is applied and the inputs are returned unchanged.
- rng_collection
The name of the RNG collection to use when requesting an RNG key ('dropout' by default).