flax.linen.remat_scan
- flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))
Combines remat and scan for memory efficiency and constant time compilation.
remat_scan allows for constant compile times and memory usage that grows sublinearly with model depth, at the cost of a small constant-factor penalty in computation. This is typically beneficial for very deep models.
Example:
    import flax.linen as nn

    class BigModel(nn.Module):
      @nn.compact
      def __call__(self, x):
        DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
        # 100x dense with O(sqrt(N)) memory for gradient computation
        return DenseStack(8, name="dense_stack")(x)
- Parameters:
  - target (TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])) – a Module or a function taking a Module as its first argument.
  - lengths (Optional[Sequence[int]]) – the number of loop iterations at each level of the nested loop. The total number of iterations is n = prod(lengths). Each loop is rematerialized, so memory consumption is proportional to n^(1/d), where d = len(lengths). Minimal memory consumption requires tuning lengths so that the same amount of memory is consumed at each level of the nested loop.
  - policy (Optional[Callable[..., bool]]) – experimental checkpoint policy; see jax.checkpoint.
  - variable_broadcast (Union[bool, str, Collection[str], DenyList]) – specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the looped function.
  - variable_carry (Union[bool, str, Collection[str], DenyList]) – specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and are preserved when the scan finishes.
  - variable_axes (Mapping[Union[bool, str, Collection[str], DenyList], Union[int, In[int], Out[int]]]) – the variable collections that are scanned over. Defaults to {True: 0}.
  - split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – split PRNG sequences differ across loop iterations. If split is False, the PRNGs are the same in every iteration. Defaults to {True: True}.
- Return type:
  TypeVar(Target, bound=Union[Type[Module], Callable[..., Any]])
- Returns:
  A wrapped version of target that repeats itself prod(lengths) times.
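The lengths description above implies that, for a fixed total depth n, lengths should be a balanced factorization of n. The helper below is a hypothetical utility (not part of Flax) that searches all d-way factorizations of n and returns the one with the smallest sum. It relies on a simplified cost model, which is an assumption: peak stored activations are taken to scale with sum(lengths), i.e. roughly one saved carry per iteration of each level. With equal lengths this sum is d * n^(1/d), matching the n^(1/d) scaling stated above.

```python
import itertools
import math
from typing import Tuple


def balanced_lengths(n: int, d: int) -> Tuple[int, ...]:
  """Hypothetical helper: choose d loop lengths whose product is n.

  Assumed cost model (not Flax's): peak memory ~ sum(lengths), so among all
  factorizations with prod(lengths) == n, return one with the smallest sum.
  """
  divisors = [k for k in range(1, n + 1) if n % k == 0]
  best = None
  for combo in itertools.product(divisors, repeat=d):
    if math.prod(combo) == n and (best is None or sum(combo) < sum(best)):
      best = combo
  return best


print(balanced_lengths(100, 2))  # (10, 10)
print(balanced_lengths(64, 3))   # (4, 4, 4)
```

Under this model, 100 layers with lengths=(100,) keep about 100 carries, while lengths=(10, 10) keep about 20.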