flax.linen.make_attention_mask
- flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=jnp.multiply, extra_batch_dims=0, dtype=jnp.float32)[source]
Mask-making helper for attention weights.
In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].
- Parameters:
  - query_input (Any) – a batched, flat input of query_length size
  - key_input (Any) – a batched, flat input of key_length size
  - pairwise_fn (Callable[..., Any]) – broadcasting elementwise comparison function
  - extra_batch_dims (int) – number of extra batch dims to add singleton axes for, none by default
  - dtype (Any) – mask return dtype
- Returns:
A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
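A minimal usage sketch for building a self-attention padding mask (the token batch below is a hypothetical example in which 0 marks padding positions):

```python
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical padded token batch: 0 marks padding positions.
tokens = jnp.array([[5, 7, 2, 0, 0],
                    [3, 1, 0, 0, 0]])  # shape [batch=2, len=5]

# Padding mask: entry (q, kv) is 1.0 only when both the query token q
# and the key token kv are non-padding; the default pairwise_fn
# (jnp.multiply) effectively ANDs the two boolean inputs.
mask = nn.make_attention_mask(tokens > 0, tokens > 0)

print(mask.shape)  # (2, 1, 5, 5) -> [batch..., 1, len_q, len_kv]
print(mask.dtype)  # float32, the default mask return dtype
```

The singleton axis broadcasts the mask across attention heads. Passing a different pairwise_fn changes the comparison; for instance, a causal mask can be built by comparing broadcast position indices, which mirrors how flax.linen.make_causal_mask uses this helper:

```python
# Causal mask: query position q may attend to key position kv iff q >= kv.
positions = jnp.broadcast_to(jnp.arange(tokens.shape[-1]), tokens.shape)
causal = nn.make_attention_mask(positions, positions,
                                pairwise_fn=jnp.greater_equal)
print(causal.shape)  # (2, 1, 5, 5)
```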