jax.lib.xla_bridge.get_compile_options
- jax.lib.xla_bridge.get_compile_options(num_replicas, num_partitions, device_assignment=None, use_spmd_partitioning=True, use_auto_spmd_partitioning=False, auto_spmd_partitioning_mesh_shape=None, auto_spmd_partitioning_mesh_ids=None, env_options_overrides=None, fdo_profile=None, detailed_logging=True)
Returns the compile options to use, as derived from flag values.
- Parameters:
num_replicas (int) – Number of replicas for which to compile.
num_partitions (int) – Number of partitions for which to compile.
device_assignment – Optional ndarray of jax devices indicating the assignment of logical replicas to physical devices (default inherited from xla_client.CompileOptions). Must be consistent with num_replicas and num_partitions.
use_spmd_partitioning (bool) – Whether to enable SPMD (True) or MPMD (False) partitioning in XLA.
use_auto_spmd_partitioning (bool) – Whether to automatically generate XLA shardings for the SPMD partitioner.
auto_spmd_partitioning_mesh_shape (Optional[list[int]]) – Device mesh shape used to create the auto-SPMD partitioning search space.
auto_spmd_partitioning_mesh_ids (Optional[list[int]]) – Device IDs used to create the auto-SPMD partitioning search space.
env_options_overrides (Optional[dict[str, str]]) – Dict of additional options parsed by the compiler.
fdo_profile (Optional[bytes]) – Optional profile for feedback-directed optimization (FDO), passed to XLA.
detailed_logging (bool) – Whether this is an “interesting” computation for which XLA should log detailed compilation information.
- Return type:
CompileOptions
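The device_assignment parameter must be consistent with num_replicas and num_partitions. A minimal pure-Python sketch of such a consistency check follows. It assumes the assignment is laid out as a 2D grid with one row per logical replica and one column per partition, and that a flat 1D assignment is accepted only when num_partitions is 1; both layout rules are assumptions, not taken from the text above. `normalize_device_assignment` is a hypothetical helper, not a JAX function, and plain integer device IDs stand in for the ndarray of jax devices.

```python
from typing import List, Sequence, Union

# Integer device IDs stand in for jax devices in this hypothetical sketch.
DeviceIds = Union[Sequence[int], Sequence[Sequence[int]]]


def normalize_device_assignment(
    device_assignment: DeviceIds, num_replicas: int, num_partitions: int
) -> List[List[int]]:
    """Return a replicas x partitions grid of device IDs, or raise ValueError.

    Assumed layout (not part of the documented API): rows are logical
    replicas, columns are partitions.
    """
    rows = list(device_assignment)

    # Assumption: a flat 1D assignment is only valid with a single partition,
    # in which case each device becomes a one-element row.
    if all(isinstance(d, int) for d in rows):
        if num_partitions != 1:
            raise ValueError("a 1D device_assignment requires num_partitions == 1")
        rows = [[d] for d in rows]

    grid = [list(row) for row in rows]

    # The number of rows must match the number of replicas.
    if len(grid) != num_replicas:
        raise ValueError(
            f"device_assignment has {len(grid)} rows, "
            f"expected num_replicas={num_replicas}"
        )
    # Every row must have one entry per partition.
    for row in grid:
        if len(row) != num_partitions:
            raise ValueError(
                f"device_assignment row {row} has {len(row)} entries, "
                f"expected num_partitions={num_partitions}"
            )
    return grid


print(normalize_device_assignment([[0, 1], [2, 3]], num_replicas=2, num_partitions=2))
# [[0, 1], [2, 3]]
print(normalize_device_assignment([0, 1, 2, 3], num_replicas=4, num_partitions=1))
# [[0], [1], [2], [3]]
```

Failing early with a descriptive ValueError makes a shape mismatch easier to diagnose than an error raised later during compilation.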
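When use_auto_spmd_partitioning is enabled, auto_spmd_partitioning_mesh_shape and auto_spmd_partitioning_mesh_ids together describe the device mesh used to build the auto-sharding search space. The sketch below validates a candidate pair before it is passed in. It assumes, beyond what the parameter descriptions say, that the ID list is flat with exactly one unique device ID per mesh position, so its length equals the product of the mesh dimensions. `make_auto_spmd_mesh` is a hypothetical helper, not a JAX function.

```python
import math
from typing import List, Sequence, Tuple


def make_auto_spmd_mesh(
    mesh_shape: Sequence[int], device_ids: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Return (mesh_shape, mesh_ids) as plain lists of ints, or raise ValueError.

    Assumption: mesh_ids is a flat list holding one unique device ID per
    mesh position, so len(mesh_ids) == product(mesh_shape).
    """
    shape = [int(d) for d in mesh_shape]
    if not shape or any(d <= 0 for d in shape):
        raise ValueError(f"mesh dimensions must be positive, got {shape}")

    ids = [int(i) for i in device_ids]
    expected = math.prod(shape)
    if len(ids) != expected:
        raise ValueError(
            f"mesh shape {shape} has {expected} positions, got {len(ids)} device IDs"
        )
    if len(set(ids)) != len(ids):
        raise ValueError(f"device IDs must be unique, got {ids}")
    return shape, ids


# A 2 x 4 mesh over eight devices with IDs 0..7.
print(make_auto_spmd_mesh([2, 4], range(8)))
# ([2, 4], [0, 1, 2, 3, 4, 5, 6, 7])
```

Returning plain `list[int]` values matches the documented Optional[list[int]] types of both parameters.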