flax.linen.MultiHeadDotProductAttention
- class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
Multi-head dot-product attention.
- num_heads
Number of attention heads. The QKV feature dimension (qkv_features, which defaults to inputs_q.shape[-1]) must be divisible by the number of heads.
- dtype
The dtype of the computation (default: inferred from inputs and params).
- param_dtype
The dtype passed to parameter initializers (default: float32).
- qkv_features
Dimension of the query, key, and value projections (default: inputs_q.shape[-1]).
- out_features
Dimension of the final output projection (default: inputs_q.shape[-1]).
- broadcast_dropout
bool: whether to broadcast dropout along the batch dimensions.
- dropout_rate
Dropout rate applied to the attention weights.
- deterministic
If False, the attention weights are randomly masked using dropout; if True, they are deterministic.
- precision
Numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init
Initializer for the kernel of the Dense layers.
- bias_init
Initializer for the bias of the Dense layers.
- use_bias
bool: whether the pointwise query, key, value, and output (QKVO) dense projections use a bias.
- attention_fn
dot_product_attention or a compatible function. It accepts query, key, and value, and returns an output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].
- decode
Whether to prepare and use an autoregressive cache.
- normalize_qk
Whether to apply QK normalization (arxiv.org/abs/2302.05442).
- Parameters:
num_heads (int)
param_dtype (Any)
broadcast_dropout (bool)
dropout_rate (float)
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]])
kernel_init (Callable[[Array, Tuple[int, ...], Any], Any])
use_bias (bool)
decode (bool)
normalize_qk (bool)
qkv_dot_general_cls (Any)
out_dot_general_cls (Any)
parent (Union[Type[Module], Scope, Type[_Sentinel], None])