Layers#

Linear Modules#

Pooling#

flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#

Pools the input by taking the maximum of a window slice.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).

Returns:

The maximum for each window slice.
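
A minimal usage sketch (the input shape, window, and strides below are illustrative):

>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> # (batch, height, width, features)
>>> x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
>>> y = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
>>> y.shape
(1, 2, 2, 1)
>>> # y[0, :, :, 0] holds the per-window maxima [[5., 7.], [13., 15.]]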

flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#

Pools the input by taking the average over a window.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).

  • count_include_pad – a boolean indicating whether to include padded tokens in the average calculation (default: True).

Returns:

The average for each window slice.
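
A minimal usage sketch (shapes are illustrative; with an all-ones input every window average is 1):

>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.ones((1, 4, 4, 3))  # (batch, height, width, features)
>>> y = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
>>> y.shape
(1, 2, 2, 3)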

flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#

Helper function to define pooling functions.

Pooling functions are implemented using the ReduceWindow XLA op. Note that pooling is not generally differentiable: providing a differentiable reduce_fn does not imply that pool is differentiable.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • init – the initial value for the reduction.

  • reduce_fn – a reduce function of the form (T, T) -> T.

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

Returns:

The output of the reduction for each window slice.
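
As a sketch, a custom min-pool can be built from pool by supplying jax.lax.min as the reduce_fn and +inf as the initial value (the input shape, window, and strides are illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
>>> y = nn.pool(x, jnp.inf, jax.lax.min, window_shape=(2, 2), strides=(2, 2), padding='VALID')
>>> y.shape
(1, 2, 2, 1)
>>> # y[0, :, :, 0] holds the per-window minima [[0., 2.], [8., 10.]]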

Normalization#

Combinators#

Stochastic#

Attention#

flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#

Computes dot-product attention weights given query and key.

This function is used by dot_product_attention(), which is what you’ll most likely use. But if you want access to the attention weights for introspection, you can call this function directly and perform the combining einsum over the values yourself.

Parameters:
  • query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].

  • key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].

  • bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout (bool) – whether to use a broadcasted dropout along batch dims.

  • dropout_rng (Optional[Array]) – JAX PRNGKey to be used for dropout.

  • dropout_rate (float) – dropout rate.

  • deterministic (bool) – whether to run deterministically, i.e., without applying dropout.

  • dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs and params).

  • precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.

Returns:

Output of shape [batch…, num_heads, q_length, kv_length].
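
A minimal usage sketch (the batch size, lengths, head count, and depths are illustrative); the closing einsum shows one way to combine the weights with values so the result matches the output layout of dot_product_attention():

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> # [batch, q_length, num_heads, qk_depth_per_head]
>>> query = jax.random.normal(k1, (2, 5, 4, 8))
>>> # [batch, kv_length, num_heads, qk_depth_per_head]
>>> key = jax.random.normal(k2, (2, 7, 4, 8))
>>> weights = nn.dot_product_attention_weights(query, key)
>>> weights.shape  # [batch, num_heads, q_length, kv_length]
(2, 4, 5, 7)
>>> # combine the weights with values yourself:
>>> value = jax.random.normal(k3, (2, 7, 4, 16))
>>> out = jnp.einsum('...hqk,...khd->...qhd', weights, value)
>>> out.shape  # [batch, q_length, num_heads, v_depth_per_head]
(2, 5, 4, 16)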

flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#

Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.

Note: query, key, value needn’t have any batch dimensions.

Parameters:
  • query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].

  • key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].

  • value (Any) – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].

  • bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout (bool) – whether to use a broadcasted dropout along batch dims.

  • dropout_rng (Optional[Array]) – JAX PRNGKey to be used for dropout.

  • dropout_rate (float) – dropout rate.

  • deterministic (bool) – whether to run deterministically, i.e., without applying dropout.

  • dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs).

  • precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.

Returns:

Output of shape [batch…, q_length, num_heads, v_depth_per_head].
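
A minimal usage sketch (the batch size, lengths, head count, and depths are illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> query = jax.random.normal(k1, (2, 5, 4, 8))   # [batch, q_length, num_heads, qk_depth_per_head]
>>> key = jax.random.normal(k2, (2, 7, 4, 8))     # [batch, kv_length, num_heads, qk_depth_per_head]
>>> value = jax.random.normal(k3, (2, 7, 4, 16))  # [batch, kv_length, num_heads, v_depth_per_head]
>>> out = nn.dot_product_attention(query, key, value)
>>> out.shape  # [batch, q_length, num_heads, v_depth_per_head]
(2, 5, 4, 16)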

flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=jax.numpy.multiply, extra_batch_dims=0, dtype=jax.numpy.float32)[source]#

Mask-making helper for attention weights.

In the case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].

Parameters:
  • query_input (Any) – a batched, flat input of query_length size

  • key_input (Any) – a batched, flat input of key_length size

  • pairwise_fn (Callable[..., Any]) – broadcasting elementwise comparison function

  • extra_batch_dims (int) – number of extra batch dims to add singleton axes for, none by default

  • dtype (Any) – mask return dtype

Returns:

A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
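
A minimal usage sketch building a padding mask with the default pairwise_fn (jax.numpy.multiply); the token pattern is illustrative:

>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> valid = jnp.array([[1, 1, 1, 0, 0]])  # 1 marks a real token, 0 marks padding
>>> mask = nn.make_attention_mask(valid, valid)
>>> mask.shape  # [batch, 1, len_q, len_kv]
(1, 1, 5, 5)
>>> # mask[0, 0, q, k] is 1 only where both position q and position k are real tokens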

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=jax.numpy.float32)[source]#

Make a causal mask for self-attention.

In the case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].

Parameters:
  • x (Any) – input array of shape [batch…, len]

  • extra_batch_dims (int) – number of batch dims to add singleton axes for, none by default

  • dtype (Any) – mask return dtype

Return type:

Any

Returns:

A [batch…, 1, len, len] shaped causal mask for 1d attention.
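
A minimal usage sketch (only the shape of x matters here; its values are not used):

>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.ones((1, 4))  # [batch, len]
>>> mask = nn.make_causal_mask(x)
>>> mask.shape  # [batch, 1, len, len]
(1, 1, 4, 4)
>>> # mask[0, 0] is lower-triangular: position q may attend to positions k <= q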

Recurrent#

Summary

Dense(features[, use_bias, dtype, ...]) – A linear transformation applied over the last dimension of the input.

DenseGeneral(features[, axis, batch_dims, ...]) – A linear transformation with flexible axes.

Conv(features, kernel_size[, strides, ...]) – Convolution Module wrapping lax.conv_general_dilated.

ConvTranspose(features, kernel_size[, ...]) – Convolution Module wrapping lax.conv_transpose.

ConvLocal(features, kernel_size[, strides, ...]) – Local convolution Module wrapping lax.conv_general_dilated_local.

Embed(num_embeddings, features[, dtype, ...]) – Embedding Module.

BatchNorm([use_running_average, axis, ...]) – BatchNorm Module.

LayerNorm([epsilon, dtype, param_dtype, ...]) – Layer normalization (https://arxiv.org/abs/1607.06450).

GroupNorm([num_groups, group_size, epsilon, ...]) – Group normalization (https://arxiv.org/abs/1803.08494).

RMSNorm([epsilon, dtype, param_dtype, ...]) – RMS layer normalization (https://arxiv.org/abs/1910.07467).

SpectralNorm(layer_instance[, n_steps, ...]) – Spectral normalization.

WeightNorm(layer_instance[, epsilon, dtype, ...]) – L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf).

Sequential(layers[, parent, name]) – Applies a linear chain of Modules.

Dropout(rate[, broadcast_dims, ...]) – Create a dropout layer.

SelfAttention(num_heads[, dtype, ...]) – Self-attention special case of multi-head dot-product attention.

MultiHeadDotProductAttention(num_heads[, ...]) – Multi-head dot-product attention.

RNNCellBase([parent, name]) – RNN cell base class.

LSTMCell(features[, gate_fn, activation_fn, ...]) – LSTM cell.

OptimizedLSTMCell(features[, gate_fn, ...]) – More efficient LSTM cell that concatenates state components before matmul.

GRUCell(features[, gate_fn, activation_fn, ...]) – GRU cell.

RNN(cell[, time_major, return_carry, ...]) – The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().

Bidirectional(forward_rnn, backward_rnn[, ...]) – Processes the input in both directions and merges the results.

max_pool(inputs, window_shape[, strides, ...]) – Pools the input by taking the maximum of a window slice.

avg_pool(inputs, window_shape[, strides, ...]) – Pools the input by taking the average over a window.

pool(inputs, init, reduce_fn, window_shape, ...) – Helper function to define pooling functions.

dot_product_attention_weights(query, key[, ...]) – Computes dot-product attention weights given query and key.

dot_product_attention(query, key, value[, ...]) – Computes dot-product attention given query, key, and value.

make_attention_mask(query_input, key_input) – Mask-making helper for attention weights.

make_causal_mask(x[, extra_batch_dims, dtype]) – Make a causal mask for self-attention.