flax.linen.map_variables
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)
Map variables inside a module.

map_variables can be used to transform the variables inside a module both before and after the module is applied. This is useful, among other things, for masking the weights of a module without having to modify the module itself.

- Example::
  >>> import jax
  >>> import jax.numpy as jnp
  >>> import flax.linen as nn
  ...
  >>> class CausalDense(nn.Module):
  ...   '''A dense layer that masks the weights such that the output is
  ...   causal, i.e. output i only depends on input <= i.
  ...   '''
  ...   features: int
  ...
  ...   def apply_mask(self, variables):
  ...     # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
  ...     return (jax.tree_util.tree_map(jnp.triu, variables)
  ...             if not self.is_initializing() else variables)
  ...
  ...   def setup(self):
  ...     # temporary class
  ...     _CausalDense = nn.map_variables(
  ...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
  ...     self.dense = _CausalDense(features=self.features, use_bias=False)
  ...
  ...   def __call__(self, x):
  ...     return self.dense(x)
  ...
  >>> module = CausalDense(features=5)
  >>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
- Parameters:
  - target (TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])) – the module or function to be transformed.
  - mapped_collections (Union[bool, str, Collection[str], DenyList]) – the collection(s) to be transformed.
  - trans_in_fn (Callable[..., Any]) – modifies the variables before applying the module or function.
  - trans_out_fn (Callable[..., Any]) – modifies the variables after applying the module or function. It is only applied if either init or mutable is not False.
  - init (bool) – If True, variables are initialized before transformation.
  - mutable (bool) – If True, the mapped variable collections will be mutable.
  - rngs (Union[bool, str, Collection[str], DenyList]) – PRNGSequences added to the transformed scope (default: all).
  - variables (Union[bool, str, Collection[str], DenyList]) – additional Variable collections added to the transformed scope, besides those specified by target (default: all).
  - methods – If target is a Module, the methods of Module to map variables for.
- Return type:
  TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])
- Returns:
  a wrapped version of target that will map the specified collections.