flax.linen.scan
- flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None)
A lifted version of jax.lax.scan. See jax.lax.scan for the unlifted scan in JAX.

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between three different types of values inside the loop:

- scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs are stacked along the scan axis.
- carry: a value that is updated at each loop iteration. It must have the same shape and dtype throughout the loop.
- broadcast: a value that is closed over by the loop. Broadcast variables are typically initialized inside the loop body but are independent of the loop variables.
The target should have the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop. Example:
```python
>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape = x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)
```
Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. The previous example could also be written using the functional form as:

```python
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape = x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))
```
You can also use scan to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop. This works when you have a sequence of identical layers that you want to apply iteratively to an input. For example:

```python
>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))
```
To reduce both compilation time and memory usage, you can use remat_scan(), which additionally checkpoints each layer in the scan loop.

- Parameters:
  - target (TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])) – a Module or a function taking a Module as its first argument.
  - variable_axes (Mapping[Union[bool, str, Collection[str], DenyList], Union[int, In[int], Out[int]]]) – the variable collections that are scanned over, each mapped to its scan axis (for example {'params': 0}).
  - variable_broadcast (Union[bool, str, Collection[str], DenyList]) – specifies the broadcast variable collections. A broadcast variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.
  - variable_carry (Union[bool, str, Collection[str], DenyList]) – specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and are preserved when the scan finishes.
  - split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – PRNG sequences that are split differ between loop iterations. If split is False, the PRNGs are the same across iterations.
  - in_axes – specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.
  - out_axes – specifies the axis to scan over for the return value. Should be a prefix tree of the return value.
  - length (Optional[int]) – specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.
  - reverse (bool) – if True, scan from end to start in reverse order.
  - unroll (int) – how many scan iterations to unroll within a single iteration of a loop (default: 1).
  - data_transform (Optional[Callable[..., Any]]) – optional function to transform raw functional-core variable and rng groups inside the lifted scan body_fn, intended for inline SPMD annotations.
  - metadata_params (Mapping[Any, Any]) – arguments dict passed to AxisMetadata instances in the variable tree.
  - methods – if target is a Module, the methods of the Module to scan over.
- Return type:
  TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])
- Returns:
  The scan function with the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.