jax.experimental.multihost_utils.host_local_array_to_global_array

jax.experimental.multihost_utils.host_local_array_to_global_array(local_inputs, global_mesh, pspecs)

Converts a host local value to a globally sharded jax.Array.

You can use this function to transition to jax.Array. Using jax.Array with pjit has the same semantics as using GDA (GlobalDeviceArray) with pjit, i.e. all jax.Array inputs to pjit should be globally shaped.

If you are currently passing host local values to pjit, you can use this function to convert them to global Arrays and then pass those Arrays to pjit. Example usage:

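>>> from jax.experimental import multihost_utils
>>>
>>> # Assemble each host's local values into globally shaped jax.Arrays.
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
...     host_local_inputs, global_mesh, in_pspecs)
>>>
>>> # pjit now sees globally shaped inputs, as it would with GDA.
>>> with global_mesh:
...   global_out = pjitted_fun(global_inputs)
>>>
>>> # Convert the global output back to this host's local portion.
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
...     global_out, global_mesh, out_pspecs)
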
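The example above assumes that host_local_inputs, in_pspecs, pjitted_fun and out_pspecs are defined elsewhere. Below is a minimal, self-contained sketch that runs in a single process. The mesh axis name "data", the array host_local_batch and the function double are illustrative choices, not part of the API. It uses jax.jit, which in recent JAX releases shares its implementation with pjit. In a single-process run every device is addressable, so the global and host-local shapes coincide; under multi-process execution each process would pass only its own slice of the batch.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# A 1-D mesh over every device this process can see; "data" is an
# illustrative axis name, not something the API requires.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_devices = mesh.devices.size

# This host's portion of the batch (in a single-process run, the whole batch).
# The leading dimension is a multiple of the device count so it splits evenly.
host_local_batch = np.arange(8 * n_devices, dtype=np.float32).reshape(4 * n_devices, 2)

# Shard the batch dimension over "data"; the trailing dimension is not partitioned.
global_batch = multihost_utils.host_local_array_to_global_array(
    host_local_batch, mesh, P("data"))


@jax.jit
def double(x):
    # Hypothetical stand-in for a real pjit-compiled computation.
    return x * 2


global_out = double(global_batch)

# Back to a host-local view of the result for host-side processing.
host_local_out = multihost_utils.global_array_to_host_local_array(
    global_out, mesh, P("data"))
```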
Parameters:
  • local_inputs (Any) – A Pytree of host local values.

  • global_mesh (Mesh) – A jax.sharding.Mesh object.

  • pspecs (Any) – A Pytree of jax.sharding.PartitionSpecs.
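To make "globally shaped" concrete: in a multi-process run, a dimension sharded over mesh axes that span several hosts is larger globally than on any single host. The hypothetical helper global_shape_from_local below sketches that arithmetic, assuming every device along the named mesh axes holds an equal, contiguous block and that the per-process ("local") sizes of those axes are known. It is an illustration, not a JAX function. A plain tuple stands in for a PartitionSpec: None marks a dimension that is not partitioned, a string names one mesh axis, and a tuple of strings names several.

```python
from math import prod
from typing import Mapping, Sequence, Tuple, Union

# Hypothetical stand-in for one entry of a PartitionSpec.
AxisSpec = Union[None, str, Tuple[str, ...]]


def global_shape_from_local(
    local_shape: Sequence[int],
    spec: Sequence[AxisSpec],
    global_axis_sizes: Mapping[str, int],
    local_axis_sizes: Mapping[str, int],
) -> Tuple[int, ...]:
    """Hypothetical helper: the global shape implied by a host-local shape.

    Assumes every device along the named mesh axes holds an equal,
    contiguous block. Dimensions beyond len(spec) are not partitioned.
    """
    result = []
    for dim, size in enumerate(local_shape):
        entry = spec[dim] if dim < len(spec) else None
        if entry is None:
            # Not partitioned: every host already holds the full dimension.
            result.append(size)
            continue
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        local_devices = prod(local_axis_sizes[name] for name in names)
        global_devices = prod(global_axis_sizes[name] for name in names)
        if size % local_devices:
            raise ValueError(
                f"dim {dim} of size {size} does not split over {local_devices} local devices")
        # Each device holds one block; the global dimension stacks the blocks
        # of every device along these axes, across all hosts.
        per_device = size // local_devices
        result.append(per_device * global_devices)
    return tuple(result)
```

For example, with two processes that each drive four devices along a "data" axis of global size 8, a host-local array of shape (16, 128) with spec ("data", None) corresponds to a global shape of (32, 128).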