flax.linen.DenseGeneral

class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)

A linear transformation with flexible axes.

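features

int or tuple of ints giving the number of output features. A tuple produces one output axis per entry, placed after the input axes that are not contracted.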

axis

int or tuple of ints giving the input axes to contract with the kernel. For instance, (-2, -1) applies the transformation to the last two axes.

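batch_dims

tuple of batch axes. Each index along these axes gets its own kernel (and bias, if used), and the batch axes must be consecutive leading axes starting at 0, e.g. (0,) or (0, 1).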

use_bias

whether to add a bias to the output (default: True).

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

kernel_init

initializer function for the weight matrix.

bias_init

initializer function for the bias.

precision

numerical precision of the computation; see jax.lax.Precision for details.

Parameters: