flax.linen.Dense
- class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
A linear transformation applied over the last dimension of the input.
- features
the number of output features.
- use_bias
whether to add a bias to the output (default: True).
- dtype
the dtype of the computation (default: infer from input and params).
- param_dtype
the dtype passed to parameter initializers (default: float32).
- precision
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init
initializer function for the weight matrix.
- bias_init
initializer function for the bias.