flax.linen.while_loop

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))

Lifted version of jax.lax.while_loop.
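For reference, `jax.lax.while_loop(cond_fun, body_fun, init_val)` behaves like the plain Python loop below. The JAX documentation describes its semantics with an equivalent snippet. `nn.while_loop` keeps this structure but also passes a Module to both functions. In the sketch, `reference_while_loop` is a hypothetical helper written only for illustration. It compares the two versions on a small counter.

```python
import jax
import jax.numpy as jnp


def reference_while_loop(cond_fun, body_fun, init_val):
  # Hypothetical pure-Python reference: same semantics, but not traceable by jit.
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val


def cond_fun(val):
  i, _ = val
  return i < 5  # keep looping while the counter is below 5


def body_fun(val):
  i, total = val
  return i + 1, total + i  # carry keeps the same structure, shapes and dtypes


init = (jnp.array(0), jnp.array(0))
i, total = jax.lax.while_loop(cond_fun, body_fun, init)
ref_i, ref_total = reference_while_loop(cond_fun, body_fun, init)
print(int(i), int(total))  # 5 10
```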

The lifted scope is passed to cond_fn and body_fn. Broadcast variables are immutable. The carry variables are mutable but cannot change shape or dtype. This also means that variables cannot be initialized inside the body. If variable initialization is required, call body_fn once manually before calling while_loop.
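The restriction on carried state follows from `jax.lax.while_loop` itself. That function requires `body_fun` to return a value with the same structure, shapes, and dtypes as its input.

The plain-JAX sketch below makes this concrete. It uses no Flax, and `grow_until_long` and `fill_buffer` are illustrative names only. The first function has a carry that grows, and JAX rejects it with a `TypeError` at trace time. The second shows the usual workaround: preallocate a fixed-size buffer and carry a write index.

```python
import jax
import jax.numpy as jnp


def grow_until_long():
  # Each iteration doubles the array length, so the carry changes shape.
  return jax.lax.while_loop(
      lambda c: c.sum() < 3.0,
      lambda c: jnp.concatenate([c, c]),
      jnp.ones(1),
  )


def fill_buffer(n, size=4):
  # Fixed-shape alternative: preallocate `size` slots and carry a write index.
  def cond(carry):
    i, _ = carry
    return i < n

  def body(carry):
    i, buf = carry
    return i + 1, buf.at[i].set(i * 2.0)

  _, buf = jax.lax.while_loop(cond, body, (jnp.array(0), jnp.zeros(size)))
  return buf


try:
  grow_until_long()
except TypeError as err:
  print("rejected:", type(err).__name__)

print(fill_buffer(3))  # [0. 2. 4. 0.]
```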

Example:

import jax.numpy as jnp
from jax import random
import flax.linen as nn

class WhileLoopExample(nn.Module):
  @nn.compact
  def __call__(self, x):
    def cond_fn(mdl, c):
      # Keep looping while the carried counter is below 10.
      return mdl.variables['state']['acc'] < 10
    def body_fn(mdl, c):
      acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
      acc.value += 1
      y = nn.Dense(c.shape[-1])(c)
      return y
    c = x
    if self.is_mutable_collection('params'):
      # During init, call the body once directly so 'params' and 'state' get created.
      return body_fn(self, c)
    else:
      return nn.while_loop(cond_fn, body_fn, self, c,
                           carry_variables='state')

k = random.key(0)
x = jnp.ones((2, 2))
initial_vars = WhileLoopExample().init(k, x)
result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
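`init` has already run `body_fn` once, so the stored `acc` is 1. The `apply` call therefore performs 9 loop iterations with the same `Dense_0` parameters and returns a state in which `acc` is 10.

The plain-JAX sketch below mirrors that behavior by hand. It is a simplification, not Flax's implementation, which also handles details such as PRNG splitting through split_rngs. The parameters are closed over, so they are read-only like a broadcast collection. The counter travels in the loop carry, like `carry_variables='state'`.

`dense` and `run_loop` are hypothetical helpers. The kernel is fixed to `2 * I` only so the result is easy to check; in the Flax example it is randomly initialized.

```python
import jax
import jax.numpy as jnp


def dense(params, x):
  # Same affine map that nn.Dense computes: x @ kernel + bias.
  return x @ params['kernel'] + params['bias']


def run_loop(params, acc0, x):
  # `params` is closed over: read-only inside the loop (like a broadcast collection).
  # `acc` travels in the carry: updated every step (like carry_variables='state').
  def cond(carry):
    acc, _ = carry
    return acc < 10

  def body(carry):
    acc, y = carry
    return acc + 1, dense(params, y)

  return jax.lax.while_loop(cond, body, (acc0, x))


params = {'kernel': 2.0 * jnp.eye(2), 'bias': jnp.zeros(2)}
acc, y = run_loop(params, jnp.array(1), jnp.ones((2, 2)))
print(int(acc))  # 10
print(y)         # every entry is 2.0 ** 9 = 512.0
```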

Parameters:
  • cond_fn (Callable[[TypeVar(ModuleT, bound=Module), TypeVar(C)], bool]) – Should return True as long as the loop should continue.

  • body_fn (Callable[[TypeVar(ModuleT, bound=Module), TypeVar(C)], TypeVar(C)]) – The body of the while loop. It receives the lifted Module and the current state and returns the next state.

  • mdl (TypeVar(ModuleT, bound=Module)) – The Module which should be lifted into the loop.

  • init (TypeVar(C)) – The initial state passed to the loop.

  • carry_variables (Union[bool, str, Collection[str], DenyList]) – Collections that are carried through the loop and are therefore mutable (default: no collections).

  • broadcast_variables (Union[bool, str, Collection[str], DenyList]) – Collections that are closed over and are therefore read-only (default: all collections).

  • split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – PRNG sequences that are split will be different in each loop iteration. If split is False, the PRNGs will be the same across iterations.

Return type:

TypeVar(C)

Returns:

The final state after executing the while loop.