flax.linen.SelfAttention
- class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
Self-attention special case of multi-head dot-product attention.
- Parameters:
  - num_heads (int)
  - param_dtype (Any)
  - broadcast_dropout (bool)
  - dropout_rate (float)
  - precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]])
  - kernel_init (Callable[[Array, Tuple[int, ...], Any], Any])
  - use_bias (bool)
  - decode (bool)
  - normalize_qk (bool)
  - qkv_dot_general_cls (Any)
  - out_dot_general_cls (Any)
  - parent (Union[Type[Module], Scope, Type[_Sentinel], None])
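A minimal usage sketch (not part of the reference above; the shapes, head count, feature sizes, and PRNG seed are illustrative assumptions). Because this is the self-attention special case, queries, keys, and values are all projected from the same input array, so only one input is passed at call time:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Assumed input: batch of 4 sequences, length 16, feature dim 32.
x = jnp.ones((4, 16, 32))

# Self-attention layer: qkv_features must be divisible by num_heads.
layer = nn.SelfAttention(num_heads=8, qkv_features=64)

# Initialize parameters and apply the layer. The output keeps the input's
# feature dimension because out_features defaults to it.
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)
assert y.shape == x.shape  # (4, 16, 32)
```

With the default dropout_rate=0.0, no dropout RNG or deterministic flag needs to be supplied at apply time.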