flax.linen.init
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)
Creates an init function to call fn with a bound module.

Unlike Module.init, this function returns a new function with the signature (rngs, *args, **kwargs) -> variables. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey named “params”.

The returned init function can be composed directly with JAX transformations like jax.jit:

    import jax
    import flax.linen as nn

    # Assumes Foo is a Module with encode/decode methods, and rng and x are defined.
    def f(foo, x):
      z = foo.encode(x)
      y = foo.decode(z)
      # ...
      return y

    foo = Foo()
    f_jitted = jax.jit(nn.init(f, foo))
    variables = f_jitted(rng, x)
- Parameters:
  - fn (Callable[..., Any]) – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.
  - module (Module) – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be a bool, str, list, or DenyList. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
    By default (DenyList(deny='intermediates')), all collections except “intermediates” are mutable.
  - capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures the intermediate return values of all Modules inside the “intermediates” collection. With True, only the return values of __call__ methods are stored. Pass a function instead to change this filter: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.
- Return type:
  Callable[..., Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]
- Returns:
  The init function wrapping fn.