flax.linen.MultiHeadDotProductAttention#

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Multi-head dot-product attention.
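
Example usage, a minimal self-attention sketch (the hyperparameters and input shape below are illustrative assumptions, not part of the API)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Toy input of shape (batch, seq_len, features); features must be
    # divisible by num_heads.
    x = jnp.ones((4, 3, 32))

    # Self-attention: pass the same array as query and key/value inputs.
    layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=32)
    variables = layer.init(jax.random.PRNGKey(0), x, x)
    out = layer.apply(variables, x, x)  # output shape (4, 3, 32)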

num_heads#

number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

dtype#

the dtype of the computation (default: infer from inputs and params)

param_dtype#

the dtype passed to parameter initializers (default: float32)

qkv_features#

dimension of the key, query, and value.

out_features#

dimension of the last projection.

broadcast_dropout#

bool: use a broadcasted dropout along batch dims.

dropout_rate#

dropout rate.

deterministic#

if false, the attention weights are masked randomly using dropout, whereas if true, the attention weights are deterministic.
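
For example, a sketch of toggling attention dropout between training and evaluation (shapes and hyperparameters are assumptions for illustration)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((2, 7, 32))  # (batch, seq_len, features)
    layer = nn.MultiHeadDotProductAttention(
        num_heads=4, qkv_features=32, dropout_rate=0.1)

    params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    # Initialize with dropout disabled so no 'dropout' rng is needed.
    variables = layer.init(params_key, x, x, deterministic=True)

    # Training: stochastic dropout requires an explicit 'dropout' rng.
    train_out = layer.apply(variables, x, x, deterministic=False,
                            rngs={'dropout': dropout_key})
    # Evaluation: deterministic=True disables dropout; no rng required.
    eval_out = layer.apply(variables, x, x, deterministic=True)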

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer for the kernel of the Dense layers.

bias_init#

initializer for the bias of the Dense layers.

use_bias#

bool: whether pointwise QKVO dense transforms use bias.

attention_fn#

dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].
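
A sketch of supplying a custom attention_fn; here a thin wrapper that simply delegates to the default nn.dot_product_attention (the wrapper name and shapes are illustrative assumptions)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    def my_attention_fn(query, key, value, **kwargs):
        # A compatible attention_fn receives the projected query/key/value
        # with shape (..., length, num_heads, head_dim), plus keyword
        # arguments such as mask, dropout_rng, dropout_rate, deterministic,
        # dtype, and precision. Here we just delegate to the default; a real
        # custom function could, e.g., add relative-position biases.
        return nn.dot_product_attention(query, key, value, **kwargs)

    x = jnp.ones((2, 5, 16))
    layer = nn.MultiHeadDotProductAttention(
        num_heads=4, attention_fn=my_attention_fn)
    variables = layer.init(jax.random.PRNGKey(0), x, x)
    out = layer.apply(variables, x, x)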

decode#

whether to prepare and use an autoregressive cache.
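
A sketch of token-by-token decoding with the autoregressive cache (shapes and the loop are illustrative assumptions): the cache is allocated at init time from the provided sequence length and then updated one step per call through the mutable 'cache' collection::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    layer = nn.MultiHeadDotProductAttention(
        num_heads=2, qkv_features=8, decode=True)

    # Init with a full-length dummy input so the cache is sized for 4 steps.
    dummy = jnp.ones((1, 4, 8))
    variables = layer.init(jax.random.PRNGKey(0), dummy, dummy)
    params, cache = variables['params'], variables['cache']

    # Decode one token per call; the updated cache comes back via mutable=['cache'].
    for _ in range(4):
        token = jnp.ones((1, 1, 8))
        out, mutated = layer.apply(
            {'params': params, 'cache': cache}, token, token,
            mutable=['cache'])
        cache = mutated['cache']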

normalize_qk#

whether QK normalization should be applied (see arxiv.org/abs/2302.05442).

Parameters: