# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention core modules for Flax."""
import functools
from typing import Any, Callable, Optional, Tuple, Union, overload
import warnings
from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.linear import DotGeneralT
from flax.linen.linear import PrecisionLike
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
from flax.linen.normalization import LayerNorm
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
PRNGKey = jax.Array
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
def dot_product_attention_weights(
query: Array,
key: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention weights given query and key.
Used by :func:`dot_product_attention`, which is what you'll most likely use.
But if you want access to the attention weights for introspection, then
you can directly call this function and call einsum yourself.
Args:
query: queries for calculating attention with shape of `[batch..., q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch..., kv_length,
num_heads, qk_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is `False`.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey to be used for dropout
dropout_rate: dropout rate
deterministic: bool, if True, dropout is not applied and the output is
deterministic
dtype: the dtype of the computation (default: infer from inputs)
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
Returns:
Output of shape `[batch..., num_heads, q_length, kv_length]`.
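
Example (a minimal, illustrative sketch of computing the weights and then
combining them with values yourself; the shapes below are arbitrary)::

  import jax
  import jax.numpy as jnp
  from flax.linen import dot_product_attention_weights

  rng = jax.random.PRNGKey(0)
  q_rng, k_rng, v_rng = jax.random.split(rng, 3)
  # [batch, q_length, num_heads, qk_depth_per_head]
  query = jax.random.normal(q_rng, (2, 4, 8, 16))
  # [batch, kv_length, num_heads, qk_depth_per_head]
  key = jax.random.normal(k_rng, (2, 6, 8, 16))
  weights = dot_product_attention_weights(query, key)
  # weights.shape == (2, 8, 4, 6), i.e. [batch, num_heads, q_length, kv_length]
  # combine with values yourself via einsum:
  value = jax.random.normal(v_rng, (2, 6, 8, 16))
  out = jnp.einsum('...hqk,...khd->...qhd', weights, value)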
"""
query, key = promote_dtype(query, key, dtype=dtype)
dtype = query.dtype
assert query.ndim == key.ndim, 'q, k must have same rank.'
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
# calculate attention matrix
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk', query, key, precision=precision
)
# apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias
# apply attention mask
if mask is not None:
big_neg = jnp.finfo(dtype).min
attn_weights = jnp.where(mask, attn_weights, big_neg)
# normalize the attention weights
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
# apply attention dropout
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
if broadcast_dropout:
# dropout is broadcast across the batch + head dimensions
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
attn_weights = attn_weights * multiplier
return attn_weights
def dot_product_attention(
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of `[batch..., q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch..., kv_length,
num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch..., kv_length,
num_heads, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is `False`.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey to be used for dropout
dropout_rate: dropout rate
deterministic: bool, if True, dropout is not applied and the output is
deterministic
dtype: the dtype of the computation (default: infer from inputs)
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
Returns:
Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
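
Example (a minimal, illustrative sketch; the shapes below are arbitrary)::

  import jax
  from flax.linen import dot_product_attention

  rng = jax.random.PRNGKey(0)
  q_rng, k_rng, v_rng = jax.random.split(rng, 3)
  query = jax.random.normal(q_rng, (2, 4, 8, 16))  # [batch, q_length, heads, qk_depth]
  key = jax.random.normal(k_rng, (2, 6, 8, 16))    # [batch, kv_length, heads, qk_depth]
  value = jax.random.normal(v_rng, (2, 6, 8, 32))  # [batch, kv_length, heads, v_depth]
  out = dot_product_attention(query, key, value)
  # out.shape == (2, 4, 8, 32), i.e. [batch, q_length, heads, v_depth]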
"""
query, key, value = promote_dtype(query, key, value, dtype=dtype)
dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert (
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
), 'q, k, v batch dims must match.'
assert (
query.shape[-2] == key.shape[-2] == value.shape[-2]
), 'q, k, v num_heads must match.'
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# compute attention weights
attn_weights = dot_product_attention_weights(
query,
key,
bias,
mask,
broadcast_dropout,
dropout_rng,
dropout_rate,
deterministic,
dtype,
precision,
)
# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
)
class MultiHeadDotProductAttention(Module):
"""Multi-head dot-product attention.
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: infer from inputs and params)
param_dtype: the dtype passed to parameter initializers (default: float32)
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weights are randomly masked using
dropout, whereas if true, the attention weights are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts query,
key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,
num_heads, value_channels]`.
decode: whether to prepare and use an autoregressive cache.
normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
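
Example usage (a minimal sketch; the hyperparameters and shapes below are
illustrative, not prescriptive)::

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
  x = jnp.ones((4, 3, 5))  # [batch, length, features]
  variables = layer.init(jax.random.PRNGKey(0), x)
  out = layer.apply(variables, x)     # self-attention, shape (4, 3, 5)
  y = jnp.ones((4, 7, 5))             # a separate key/value source
  out = layer.apply(variables, x, y)  # cross-attention over y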
"""
num_heads: int
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
qkv_features: Optional[int] = None
out_features: Optional[int] = None
broadcast_dropout: bool = True
dropout_rate: float = 0.0
deterministic: Optional[bool] = None
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
initializers.zeros_init()
)
use_bias: bool = True
attention_fn: Callable[..., Array] = dot_product_attention
decode: bool = False
normalize_qk: bool = False
# Deprecated, will be removed.
qkv_dot_general: Optional[DotGeneralT] = None
out_dot_general: Optional[DotGeneralT] = None
qkv_dot_general_cls: Any = None
out_dot_general_cls: Any = None
@overload
def __call__(
self,
inputs_q: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
):
...
@overload
def __call__(
self,
inputs_q: Array,
*,
inputs_kv: Optional[Array] = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
):
...
@compact
def __call__(
self,
inputs_q: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
inputs_kv: Optional[Array] = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None
):
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of
inputs_q (self attention).
If only inputs_v is None, it will copy the value of inputs_k.
Args:
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
inputs_k will copy the value of inputs_q.
inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
inputs_v will copy the value of inputs_k.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If
None, inputs_kv will copy the value of inputs_q. This arg will be
deprecated soon. Use inputs_k and inputs_v instead.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
deterministic: if false, the attention weights are randomly masked using
dropout, whereas if true, the attention weights are deterministic.
dropout_rng: optional rng key to pass to the attention layer's dropout
mask. Otherwise, self.make_rng('dropout') is used instead.
Returns:
output of shape `[batch_sizes..., length, features]`.
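
Example of the call patterns described above (a sketch; the names and
shapes below are illustrative)::

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  layer = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=8)
  q = jnp.ones((2, 3, 8))
  kv = jnp.ones((2, 5, 8))
  variables = layer.init(jax.random.PRNGKey(0), q)
  out = layer.apply(variables, q)          # inputs_k = inputs_v = q (self attention)
  out = layer.apply(variables, q, kv)      # inputs_v copies inputs_k = kv
  out = layer.apply(variables, q, kv, kv)  # all three given explicitly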
"""
if inputs_kv is not None:
if inputs_k is not None or inputs_v is not None:
raise ValueError('If either `inputs_k` or `inputs_v` is not None, '
'`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` '
'and `inputs_v` must be None. We recommend using `inputs_k` and '
'`inputs_v` args, since `inputs_kv` will be deprecated soon. See '
'https://github.com/google/flax/discussions/3389 for more '
'information.')
inputs_k = inputs_v = inputs_kv
warnings.warn('The inputs_kv arg will be deprecated soon. '
'Use inputs_k and inputs_v instead. See '
'https://github.com/google/flax/discussions/3389 '
'for more information.',
DeprecationWarning)
else:
if inputs_k is None:
if inputs_v is not None:
raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. '
'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
'value to `inputs_k` and leave `inputs_v` as None.')
inputs_k = inputs_q
if inputs_v is None:
inputs_v = inputs_k
elif inputs_v.shape[-1] == inputs_v.shape[-2]:
warnings.warn(f"You are passing an array of shape {inputs_v.shape} "
"to the `inputs_v` arg, when you may have intended "
"to pass it to the `mask` arg. As of Flax version "
"0.7.4, the function signature of "
"MultiHeadDotProductAttention's `__call__` method "
"has changed to `__call__(inputs_q, inputs_k=None, "
"inputs_v=None, *, inputs_kv=None, mask=None, "
"deterministic=None)`. Use the kwarg `mask` instead. "
"See https://github.com/google/flax/discussions/3389 "
"and read the docstring for more information.",
DeprecationWarning)
features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
f'Memory dimension ({qkv_features}) must be divisible by number of'
f' heads ({self.num_heads}).'
)
head_dim = qkv_features // self.num_heads
dense = functools.partial(
DenseGeneral,
axis=-1,
dtype=self.dtype,
param_dtype=self.param_dtype,
features=(self.num_heads, head_dim),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
precision=self.precision,
dot_general=self.qkv_dot_general,
dot_general_cls=self.qkv_dot_general_cls,
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
query, key, value = (
dense(name='query')(inputs_q),
dense(name='key')(inputs_k),
dense(name='value')(inputs_v),
)
if self.normalize_qk:
# Normalizing query and key projections stabilizes training with higher
# LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
query = LayerNorm(name='query_ln', use_bias=False)(query) # type: ignore[call-arg]
key = LayerNorm(name='key_ln', use_bias=False)(key) # type: ignore[call-arg]
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.decode:
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
cached_key = self.variable(
'cache', 'cached_key', jnp.zeros, key.shape, key.dtype
)
cached_value = self.variable(
'cache', 'cached_value', jnp.zeros, value.shape, value.dtype
)
cache_index = self.variable(
'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
(
*batch_dims,
max_length,
num_heads,
depth_per_head,
) = cached_key.value.shape
# shape check of cached keys against query input
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
'expected query shape %s instead got %s.'
% (expected_shape, query.shape)
)
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices: tuple[Union[int, jax.Array], ...] = (0,) * len(batch_dims) + (
cur_index,
0,
0,
)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
cache_index.value = cache_index.value + 1
# causal mask for cached decoder self-attention:
# our single query position should only attend to those key
# positions that have already been generated and cached,
# not the remaining zero elements.
mask = combine_masks(
mask,
jnp.broadcast_to(
jnp.arange(max_length) <= cur_index,
tuple(batch_dims) + (1, 1, max_length),
),
)
if (
self.dropout_rate > 0.0
): # Require `deterministic` only if using dropout.
m_deterministic = merge_param(
'deterministic', self.deterministic, deterministic
)
if not m_deterministic and dropout_rng is None:
dropout_rng = self.make_rng('dropout')
else:
m_deterministic = True
# apply attention
x = self.attention_fn(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
precision=self.precision,
) # pytype: disable=wrong-keyword-args
# back to the original inputs dimensions
out = DenseGeneral(
features=features,
axis=(-2, -1),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
dot_general=self.out_dot_general,
dot_general_cls=self.out_dot_general_cls,
name='out', # type: ignore[call-arg]
)(x)
return out
class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
@compact
def __call__( # type: ignore
self,
inputs_q: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None
):
"""Applies multi-head dot product self-attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.
Args:
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
deterministic: if false, the attention weights are randomly masked using
dropout, whereas if true, the attention weights are deterministic.
Returns:
output of shape `[batch_sizes..., length, features]`.
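
Since this class will be deprecated, prefer calling
`MultiHeadDotProductAttention` with a single input (an illustrative sketch)::

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  x = jnp.ones((4, 6, 16))  # [batch, length, features]
  layer = nn.MultiHeadDotProductAttention(num_heads=8)
  variables = layer.init(jax.random.PRNGKey(0), x)
  out = layer.apply(variables, x)  # equivalent self-attention call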
"""
warnings.warn('SelfAttention will be deprecated soon. Use '
'`MultiHeadDotProductAttention.__call__(inputs_q)` instead. '
'See https://github.com/google/flax/discussions/3389 '
'for more information.',
DeprecationWarning)
return super().__call__(
inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng
)
# mask-making utility functions
def make_attention_mask(
query_input: Array,
key_input: Array,
pairwise_fn: Callable[..., Any] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32,
):
"""Mask-making helper for attention weights.
In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`), the
attention weights will be `[batch..., heads, len_q, len_kv]` and this
function will produce `[batch..., 1, len_q, len_kv]`.
Args:
query_input: a batched, flat input of query_length size
key_input: a batched, flat input of key_length size
pairwise_fn: broadcasting elementwise comparison function
extra_batch_dims: number of extra batch dims to add singleton axes for, none
by default
dtype: mask return dtype
Returns:
A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
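
Example of building a padding mask from per-token validity flags (a sketch;
the inputs below are illustrative)::

  import jax.numpy as jnp
  import flax.linen as nn

  # 1 for real tokens, 0 for padding, shape [batch, seq_len]
  tokens_valid = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=jnp.float32)
  mask = nn.make_attention_mask(tokens_valid, tokens_valid)
  # mask.shape == (2, 1, 4, 4); mask[b, 0, q, k] is 1. only when both
  # positions q and k are real (non-padding) tokens.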
"""
mask = pairwise_fn(
jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
)
mask = jnp.expand_dims(mask, axis=-3)
mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
return mask.astype(dtype)
def make_causal_mask(
x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
) -> Array:
"""Make a causal mask for self-attention.
In case of 1d inputs (i.e., `[batch..., len]`), the self-attention weights
will be `[batch..., heads, len, len]` and this function will produce a
causal mask of shape `[batch..., 1, len, len]`.
Args:
x: input array of shape `[batch..., len]`
extra_batch_dims: number of batch dims to add singleton axes for, none by
default
dtype: mask return dtype
Returns:
A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
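
Example (a sketch; only the shape of `x` matters, not its values)::

  import jax.numpy as jnp
  import flax.linen as nn

  x = jnp.ones((2, 5))             # [batch, len]
  causal = nn.make_causal_mask(x)  # shape (2, 1, 5, 5)
  # causal[b, 0, q, k] == 1. exactly when k <= q.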
"""
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
return make_attention_mask(
idxs,
idxs,
jnp.greater_equal,
extra_batch_dims=extra_batch_dims,
dtype=dtype,
)
def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Optional[Array]:
"""Combine attention masks.
Args:
*masks: set of attention mask arguments to combine, some can be None.
dtype: dtype for the returned mask.
Returns:
Combined mask reduced by logical and, or None if no masks were given.
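
Example of combining a causal mask with a padding mask (a sketch; the
inputs below are illustrative)::

  import jax.numpy as jnp
  import flax.linen as nn

  tokens_valid = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=jnp.float32)
  causal = nn.make_causal_mask(tokens_valid)                    # (2, 1, 4, 4)
  padding = nn.make_attention_mask(tokens_valid, tokens_valid)  # (2, 1, 4, 4)
  mask = nn.combine_masks(causal, padding)                      # logical and, (2, 1, 4, 4)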
"""
masks_list = [m for m in masks if m is not None]
if not masks_list:
return None
assert all(
map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}'
mask, *other_masks = masks_list
for other_mask in other_masks:
mask = jnp.logical_and(mask, other_mask)
return mask.astype(dtype)