
jax.lax.argmax

jax.lax.argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.

Parameters:
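  • operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – the input array.

  • axis (int) – the axis along which to find the maximum; this dimension is removed from the result.

  • index_dtype (Union[str, type[Any], dtype, SupportsDType]) – the integer dtype of the returned indices.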

Return type:

Array
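The sketch below shows a minimal call, assuming a standard JAX installation; the input values are arbitrary illustrations. Note that, per the signature above, both axis and index_dtype must be passed explicitly (no defaults are shown). The higher-level jax.numpy.argmax offers a NumPy-style interface.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 input with no ties, so each maximum is unique.
x = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 4.0]])

# Reduce over axis 1: one index per row; result shape is (2,).
row_idx = lax.argmax(x, axis=1, index_dtype=jnp.int32)  # values: [1, 0]

# Reduce over axis 0: one index per column; result shape is (3,).
col_idx = lax.argmax(x, axis=0, index_dtype=jnp.int32)  # values: [1, 0, 1]

print(row_idx, row_idx.dtype)
print(col_idx, col_idx.dtype)
```

The reduced axis disappears from the output shape, and the index values carry the requested index_dtype rather than the dtype of operand.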
