Layers#
Linear Modules#
Pooling#
- flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#
Pools the input by taking the maximum of a window slice.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).
padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).
- Returns:
The maximum for each window slice.
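For instance, a minimal usage sketch (the input shape and window size below are illustrative, not required by the API):

```python
import jax.numpy as jnp
import flax.linen as nn

# A tiny NHWC batch: (batch, height, width, features).
x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)

# Non-overlapping 2x2 max pooling halves each spatial dimension.
y = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
print(y.shape)  # (1, 2, 2, 1)
```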
- flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#
Pools the input by taking the average over a window.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).
padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).
count_include_pad – whether to include padded tokens in the average calculation (default: True).
- Returns:
The average for each window slice.
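A minimal sketch with 'SAME' padding (shapes are illustrative); with count_include_pad=False the averages at the borders are taken only over the real, unpadded values:

```python
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 4, 4, 3))  # (batch, height, width, features)

# 'SAME' padding with the default strides of (1, 1) keeps the spatial size.
y = nn.avg_pool(x, window_shape=(3, 3), padding='SAME', count_include_pad=False)
print(y.shape)  # (1, 4, 4, 3)
```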
- flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#
Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op. NOTE: pooling is not generally differentiable; providing a differentiable reduce_fn does not imply that pool is differentiable.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
init – the initial value for the reduction.
reduce_fn – a reduce function of the form (T, T) -> T.
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).
padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
- Returns:
The output of the reduction for each window slice.
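As an illustration, a hypothetical "min pool" built on this helper; the identity element for jnp.minimum is +inf, mirroring how max pooling uses -inf:

```python
import jax.numpy as jnp
import flax.linen as nn

x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)

# Min pooling: reduce each 2x2 window with jnp.minimum, starting from +inf.
y = nn.pool(
    x,
    init=jnp.inf,
    reduce_fn=jnp.minimum,
    window_shape=(2, 2),
    strides=(2, 2),
    padding='VALID',
)
print(y.shape)  # (1, 2, 2, 1)
```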
Normalization#
Combinators#
Stochastic#
Attention#
- flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#
Computes dot-product attention weights given query and key.
Used by dot_product_attention(), which is what you’ll most likely use. But if you want access to the attention weights for introspection, you can call this function directly and apply the einsum over the values yourself.
- Parameters:
query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].
key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].
bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout (bool) – whether to use a broadcasted dropout along batch dims.
dropout_rng (Optional[Array]) – JAX PRNGKey to be used for dropout.
dropout_rate (float) – dropout rate.
deterministic (bool) – whether to behave deterministically (i.e., not apply dropout).
dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs and params).
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.
- Returns:
Output of shape [batch…, num_heads, q_length, kv_length].
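A minimal sketch of calling this function directly and doing the value combination yourself (the shapes below are arbitrary; the einsum contraction matches the weight and value layouts documented above):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
batch, q_len, kv_len, heads, depth = 2, 4, 6, 8, 16

query = jax.random.normal(k1, (batch, q_len, heads, depth))
key = jax.random.normal(k2, (batch, kv_len, heads, depth))
value = jax.random.normal(k3, (batch, kv_len, heads, depth))

# Attention weights have shape [batch, num_heads, q_length, kv_length].
weights = nn.dot_product_attention_weights(query, key)

# Combine the values with the weights via einsum.
out = jnp.einsum('...hqk,...khd->...qhd', weights, value)
print(weights.shape, out.shape)  # (2, 8, 4, 6) (2, 4, 8, 16)
```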
- flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#
Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.
Note: query, key, value needn’t have any batch dimensions.
- Parameters:
query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].
key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].
value (Any) – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].
bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout (bool) – whether to use a broadcasted dropout along batch dims.
dropout_rng (Optional[Array]) – JAX PRNGKey to be used for dropout.
dropout_rate (float) – dropout rate.
deterministic (bool) – whether to behave deterministically (i.e., not apply dropout).
dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs).
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.
- Returns:
Output of shape [batch…, q_length, num_heads, v_depth_per_head].
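A minimal sketch of a direct call (the shapes below are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
batch, q_len, kv_len, heads, depth = 2, 4, 6, 8, 16

query = jax.random.normal(k1, (batch, q_len, heads, depth))
key = jax.random.normal(k2, (batch, kv_len, heads, depth))
value = jax.random.normal(k3, (batch, kv_len, heads, depth))

# The output keeps the query length and the per-head value depth.
out = nn.dot_product_attention(query, key, value)
print(out.shape)  # (2, 4, 8, 16)
```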
- flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<PjitFunction of <function jax.numpy.multiply>>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Mask-making helper for attention weights.
In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].
- Parameters:
query_input (Any) – a batched, flat input of query_length size.
key_input (Any) – a batched, flat input of key_length size.
pairwise_fn (Callable[..., Any]) – broadcasting elementwise comparison function.
extra_batch_dims (int) – number of extra batch dims to add singleton axes for; none by default.
dtype (Any) – mask return dtype.
- Returns:
A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
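For example, a padding mask built from a hypothetical token-id batch where 0 is the padding id (the ids below are made up):

```python
import jax.numpy as jnp
import flax.linen as nn

tokens = jnp.array([[7, 3, 9, 0, 0],
                    [5, 2, 0, 0, 0]])

# Entry (q, k) is nonzero only when both the query and the key positions
# hold real (non-padding) tokens.
mask = nn.make_attention_mask(tokens > 0, tokens > 0)
print(mask.shape)  # (2, 1, 5, 5)
```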
- flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Make a causal mask for self-attention.
In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].
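A minimal sketch; only the shape of x matters, its values are ignored:

```python
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 5))  # [batch, len]; only the shape is used

causal = nn.make_causal_mask(x)
print(causal.shape)  # (2, 1, 5, 5)
```

A causal mask and a padding mask can be combined elementwise (e.g. with flax.linen.combine_masks) before being passed to an attention function.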
Recurrent#
Summary

| Dense | A linear transformation applied over the last dimension of the input. |
| DenseGeneral | A linear transformation with flexible axes. |
| Conv | Convolution Module wrapping lax.conv_general_dilated. |
| ConvTranspose | Convolution Module wrapping lax.conv_transpose. |
| ConvLocal | Local convolution Module wrapping lax.conv_general_dilated_local. |
| Embed | Embedding Module. |
| BatchNorm | BatchNorm Module. |
| LayerNorm | Layer normalization (https://arxiv.org/abs/1607.06450). |
| GroupNorm | Group normalization (https://arxiv.org/abs/1803.08494). |
| RMSNorm | RMS Layer normalization (https://arxiv.org/abs/1910.07467). |
| SpectralNorm | Spectral normalization. |
| WeightNorm | L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf). |
| Sequential | Applies a linear chain of Modules. |
| Dropout | Create a dropout layer. |
| SelfAttention | Self-attention special case of multi-head dot-product attention. |
| MultiHeadDotProductAttention | Multi-head dot-product attention. |
| RNNCellBase | RNN cell base class. |
| LSTMCell | LSTM cell. |
| OptimizedLSTMCell | More efficient LSTM Cell that concatenates state components before matmul. |
| GRUCell | GRU cell. |
| RNN | The RNN module takes any RNNCellBase instance and applies it over a sequence. |
| Bidirectional | Processes the input in both directions and merges the results. |
| max_pool | Pools the input by taking the maximum of a window slice. |
| avg_pool | Pools the input by taking the average over a window. |
| pool | Helper function to define pooling functions. |
| dot_product_attention_weights | Computes dot-product attention weights given query and key. |
| dot_product_attention | Computes dot-product attention given query, key, and value. |
| make_attention_mask | Mask-making helper for attention weights. |
| make_causal_mask | Make a causal mask for self-attention. |