flax.linen.with_logical_partitioning

flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)

Wraps a function’s return value with LogicallyPartitioned.
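Conceptually, the wrapper passes every argument through to fn unchanged and boxes the result together with names, mesh, and rules. The dependency-free sketch below illustrates that pattern under simplifying assumptions. BoxedValue is a hypothetical stand-in for LogicallyPartitioned, wrap_return and zeros_init are hypothetical helpers, and none of this is Flax's actual source.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence, Tuple


@dataclass(frozen=True)
class BoxedValue:
    """Hypothetical stand-in for LogicallyPartitioned: a value plus metadata."""
    value: Any
    names: Tuple[Optional[str], ...]
    mesh: Any = None
    rules: Optional[Sequence[Tuple[str, Any]]] = None


def wrap_return(fn: Callable[..., Any], names, mesh=None, rules=None) -> Callable[..., BoxedValue]:
    """Hypothetical sketch: return a function that boxes fn's result."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Call the original function unchanged, then attach the metadata.
        return BoxedValue(fn(*args, **kwargs), tuple(names), mesh, rules)
    return wrapper


def zeros_init(shape):
    # Hypothetical toy "initializer" that returns nested lists, not an array.
    rows, cols = shape
    return [[0.0] * cols for _ in range(rows)]


boxed_init = wrap_return(zeros_init, (None, "data"))
result = boxed_init((2, 3))
print(result.names, result.value)
```

Because the wrapper only post-processes the return value, it works with any call signature. This is why an initializer that takes (key, shape, dtype) can be wrapped without modification.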

Example:

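import flax.linen as nn

# Leave kernel axis 0 unpartitioned; give axis 1 the logical name "data".
kernel_init = nn.with_logical_partitioning(
    nn.initializers.lecun_normal(), (None, "data"))
partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)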

Parameters:
  • fn (Callable[..., Any]) – The function to be wrapped. Typically this is an initializer.

  • names (Tuple[Optional[str], ...]) – The logical axis names passed to LogicallyPartitioned, one per dimension of fn's return value; None leaves that dimension unpartitioned.

  • mesh (Optional[Mesh]) – The mesh to use for the partitioning. If None, the global mesh resource is used if available.

  • rules (Optional[Sequence[Tuple[str, Union[str, Tuple[str], None]]]]) – Optional rules mapping logical axis names to mesh axes. If None, the global rules (for example, those set with flax.linen.logical_axis_rules) are used if available.

Return type:

Callable[..., LogicallyPartitioned]

Returns:

A function wrapping fn that will return an instance of LogicallyPartitioned.