flax.linen.dot_product_attention_weights
- flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)
Computes dot-product attention weights given query and key.
Used by dot_product_attention(), which is what you’ll most likely use. But if you want access to the attention weights for introspection, you can directly call this function and call einsum yourself; see the example sketch after the parameter list below.
- Parameters:
  - query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].
  - key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].
  - bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
  - mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
  - broadcast_dropout (bool) – use a broadcasted dropout along batch dims.
  - dropout_rng (Optional[Array]) – JAX PRNGKey to be used for dropout.
  - dropout_rate (float) – dropout rate.
  - deterministic (bool) – whether to run deterministically (if True, dropout is not applied).
  - dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs and params).
  - precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.
- Returns:
Attention weights of shape [batch…, num_heads, q_length, kv_length].
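The sketch below illustrates the workflow described above; it is an assumed usage example, not taken from the library's documentation. It calls dot_product_attention_weights() directly to get the attention weights, builds an optional causal mask with flax.linen.make_causal_mask(), and then contracts the weights with the values via einsum, mirroring what dot_product_attention() does internally. The shapes and random inputs are purely illustrative.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative shapes (assumptions, not from the docs).
batch, q_len, kv_len, num_heads, head_dim = 2, 4, 4, 8, 16
rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)

query = jax.random.normal(rng_q, (batch, q_len, num_heads, head_dim))
key = jax.random.normal(rng_k, (batch, kv_len, num_heads, head_dim))
value = jax.random.normal(rng_v, (batch, kv_len, num_heads, head_dim))

# Optional causal mask, broadcastable to [batch..., num_heads, q_length, kv_length].
mask = nn.make_causal_mask(jnp.ones((batch, q_len)))

# Attention weights with shape [batch..., num_heads, q_length, kv_length].
attn_weights = nn.dot_product_attention_weights(query, key, mask=mask)

# Combine the weights with the values yourself:
# [batch..., num_heads, q_length, kv_length] x [batch..., kv_length, num_heads, head_dim]
#   -> [batch..., q_length, num_heads, head_dim]
out = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

print(attn_weights.shape)  # (2, 8, 4, 4)
print(out.shape)           # (2, 4, 8, 16)
```

Having the weights as a separate array makes it easy to log or visualize per-head attention patterns before computing the weighted sum over the values.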