flax.linen.while_loop
- flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))
Lifted version of jax.lax.while_loop.
The lifted module (bound to the lifted scope) is passed as the first argument to cond_fn and body_fn. Broadcast variables are immutable. Carry variables are mutable, but they cannot change shape or dtype between iterations. This also means you cannot initialize variables inside the body, because creating a variable would change the structure of the loop carry. If variable initialization is required, consider calling body_fn once manually before calling while_loop.
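Because this transform lifts `jax.lax.while_loop`, the fixed-carry rule comes from the underlying JAX primitive. The following is a minimal sketch that uses only the unlifted primitive, with no Flax module involved. The names `cond_fun` and `body_fun` are illustrative. The carry is a `(step, acc)` tuple whose structure, shapes and dtypes stay the same in every iteration.

```python
import jax
import jax.numpy as jnp


def cond_fun(carry):
    step, _ = carry
    # Keep looping while fewer than 10 iterations have run.
    return step < 10


def body_fun(carry):
    step, acc = carry
    # The returned carry must match the input carry's structure, shapes and dtypes.
    return step + 1, acc * 2.0


init = (jnp.array(0), jnp.ones((2,), dtype=jnp.float32))
final_step, final_acc = jax.lax.while_loop(cond_fun, body_fun, init)
# final_step == 10 and final_acc == [1024., 1024.] (ten doublings of 1.0)
```

Returning, for example, `jnp.concatenate([acc, acc])` from `body_fun` would change the carry's shape, and JAX rejects that at trace time. The same constraint applies to carry variables in the lifted version.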
Example:

    import jax.numpy as jnp
    import flax.linen as nn
    from jax import random

    class WhileLoopExample(nn.Module):
        @nn.compact
        def __call__(self, x):
            def cond_fn(mdl, c):
                return mdl.variables['state']['acc'] < 10

            def body_fn(mdl, c):
                acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
                acc.value += 1
                y = nn.Dense(c.shape[-1])(c)
                return y

            c = x
            if self.is_mutable_collection('params'):
                # During init: run the body once so the 'state' and 'params' variables are created.
                return body_fn(self, c)
            else:
                # During apply: 'state' is carried (mutable), 'params' is broadcast (read-only).
                return nn.while_loop(cond_fn, body_fn, self, c,
                                     carry_variables='state')

    k = random.key(0)
    x = jnp.ones((2, 2))
    initial_vars = WhileLoopExample().init(k, x)
    result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])

- Parameters:
  - cond_fn (Callable[[TypeVar(ModuleT, bound=Module), TypeVar(C)], bool]) – Receives the lifted module and the current loop state. Should return True as long as the loop should continue.
  - body_fn (Callable[[TypeVar(ModuleT, bound=Module), TypeVar(C)], TypeVar(C)]) – The body of the while loop. Receives the lifted module and the current loop state, and returns the next loop state with the same structure, shapes and dtypes.
  - mdl (TypeVar(ModuleT, bound=Module)) – The Module which should be lifted into the loop.
  - init (TypeVar(C)) – The initial state passed to the loop.
  - carry_variables (Union[bool, str, Collection[str], DenyList]) – Collections that are carried through the loop and are therefore mutable (default: none).
  - broadcast_variables (Union[bool, str, Collection[str], DenyList]) – Collections that are closed over and are therefore read-only (default: all collections).
  - split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations.
- Return type:
  TypeVar(C)
- Returns:
  The final state after executing the while loop.