flax.linen.with_logical_partitioning
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)
Wraps a function’s return value with LogicallyPartitioned.
Example:
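
    import flax.linen as nn

    # Leave the kernel's input axis unnamed; name its output axis "data".
    kernel_init = nn.with_logical_partitioning(
        nn.initializers.lecun_normal(), (None, "data"))
    partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
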
- Parameters:
  - fn (Callable[..., Any]) – The function to be wrapped. Typically this is an initializer.
  - names (Tuple[Optional[str], ...]) – The logical axis names passed to LogicallyPartitioned.
  - mesh (Optional[Mesh]) – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
  - rules (Optional[Sequence[Tuple[str, Union[str, Tuple[str], None]]]]) – Optional logical-to-mesh axis rules to use. If None, the global rules are used if available.
- Return type:
  Callable[..., LogicallyPartitioned]
- Returns:
  A function wrapping fn that will return an instance of LogicallyPartitioned.