jax.experimental.io_callback
- jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)
Calls an impure Python callback.
For more explanation, see External Callbacks.
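As a minimal sketch, assuming a CPU backend and NumPy on the host (the names host_rng, host_uniform_like and random_like are hypothetical), the callback below draws samples from a host-side NumPy generator. Each call advances that generator's state, so the callback is impure, which is the situation io_callback is meant for.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical host-side mutable state: every call advances this generator,
# which is what makes the callback impure.
host_rng = np.random.default_rng(0)

def host_uniform_like(x):
    # Runs on the host; convert to NumPy and return an array of the same shape/dtype.
    x = np.asarray(x)
    return host_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def random_like(x):
    # result_shape_dtypes only needs .shape and .dtype; ShapeDtypeStruct provides both.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_uniform_like, spec, x)

x = jnp.zeros(3, dtype=jnp.float32)
a = random_like(x)
b = random_like(x)  # the host RNG state advanced, so the values should differ
```

Calling random_like twice with the same input is expected to give different results, because the host-side state changes between calls.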
- Parameters:
  - callback (Callable[..., Any]) – function to execute on the host. It is assumed to be an impure function. If callback is pure, using jax.pure_callback() instead may lead to more efficient execution.
  - result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes, and whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.
  - *args (Any) – arguments to be passed to the callback function.
  - sharding (Optional[SingleDeviceSharding]) – optional sharding that specifies the device from which the callback should be invoked.
  - ordered (bool) – boolean specifying whether sequential calls to callback must be ordered, i.e. executed in program order.
  - **kwargs (Any) – keyword arguments to be passed to the callback function.
- Returns:
  - a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.
- Return type:
result
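The sketch below exercises two of the parameters above, under the same assumptions (CPU backend, NumPy on the host; all function names are hypothetical). A dict-valued result_shape_dtypes makes io_callback return a dict of jax.Array with the same keys, and ordered=True keeps a sequence of logging callbacks in program order. The logging callback returns nothing, so it passes None as result_shape_dtypes, and jax.effects_barrier() is used to wait for pending callbacks before the host-side list is inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Pytree output: the spec is a dict of ShapeDtypeStruct leaves, so the
# result of io_callback is a dict of jax.Array with the same keys.
def host_stats(x):
    x = np.asarray(x)
    return {"mean": np.float32(x.mean()), "max": np.float32(x.max())}

STATS_SPEC = {
    "mean": jax.ShapeDtypeStruct((), jnp.float32),
    "max": jax.ShapeDtypeStruct((), jnp.float32),
}

@jax.jit
def stats(x):
    return io_callback(host_stats, STATS_SPEC, x)

# ordered=True: the three host calls below should run in program order.
events = []

def log_value(v):
    # Side effect only; returning None matches result_shape_dtypes=None.
    events.append(float(np.asarray(v)))

@jax.jit
def pipeline(x):
    io_callback(log_value, None, x, ordered=True)
    y = x * 2
    io_callback(log_value, None, y, ordered=True)
    z = y + 1
    io_callback(log_value, None, z, ordered=True)
    return z

out = stats(jnp.array([1.0, 2.0, 6.0], dtype=jnp.float32))
pipeline(jnp.float32(1.0))
jax.effects_barrier()  # wait until pending side effects have completed
```

With ordered=True, events should read [1.0, 2.0, 3.0], matching the order of the calls in pipeline.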
See also
- jax.pure_callback(): callback designed for pure functions.
- jax.debug.callback(): callback designed for general-purpose debugging.
- jax.debug.print(): callback designed for printing.
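For contrast, when the host function is pure, jax.pure_callback() accepts the same leading arguments (callback, result_shape_dtypes, *args). A minimal sketch, assuming a NumPy-only host function (host_sin and f are hypothetical names). Because a pure callback declares no side effects, JAX is free to skip it when its result is unused, so it is not a replacement for io_callback when the side effect itself matters.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Pure: the output depends only on x and no host state is modified.
    return np.sin(np.asarray(x)).astype(np.float32)

@jax.jit
def f(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, spec, x)

y = f(jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32))
```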