
jax.lax.sort_key_val

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)

Sorts keys along dimension and applies the same permutation to values.
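A minimal usage sketch on 1-D arrays with made-up data follows. Passing `jnp.arange` as `values` is one illustrative way (not something this page prescribes) to recover the sorting permutation itself, similar to an argsort.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
values = jnp.array([30.0, 10.0, 20.0])

# Sort keys ascending along the last (here, only) dimension and
# reorder values with the same permutation.
sorted_keys, sorted_values = lax.sort_key_val(keys, values)
print(sorted_keys)    # [1 2 3]
print(sorted_values)  # [10. 20. 30.]

# Using element positions as values yields the permutation itself (argsort-like).
_, perm = lax.sort_key_val(keys, jnp.arange(keys.shape[0]))
print(perm)           # [1 2 0]
```

Note that `keys` and `values` may have different dtypes (integers and floats here) as long as their shapes match.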

Parameters:
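  • keys (Array) – array of keys to sort.

  • values (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – array reordered with the same permutation as keys; must have the same shape as keys.

  • dimension (int) – dimension along which to sort. Defaults to -1 (the last dimension).

  • is_stable (bool) – if True (the default), elements with equal keys keep their original relative order.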

Return type:

tuple[Array, Array]
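The sketch below, again with made-up data, illustrates the `dimension` and `is_stable` parameters and unpacks the returned tuple of (sorted keys, permuted values). With `dimension=0`, each column of a 2-D array is sorted independently; with `is_stable=True`, tied keys keep their original relative order.

```python
import jax.numpy as jnp
from jax import lax

# dimension=0: sort each column of a 2-D array independently.
keys = jnp.array([[3, 0],
                  [1, 2],
                  [2, 1]])
values = jnp.array([[30, 100],
                    [10, 120],
                    [20, 110]])
k0, v0 = lax.sort_key_val(keys, values, dimension=0)
print(k0.tolist())  # [[1, 0], [2, 1], [3, 2]]
print(v0.tolist())  # [[10, 100], [20, 110], [30, 120]]

# is_stable=True: elements with equal keys keep their original order,
# which the positional values below make visible.
dup_keys = jnp.array([2, 1, 2, 1])
k1, v1 = lax.sort_key_val(dup_keys, jnp.arange(4), is_stable=True)
print(k1.tolist())  # [1, 1, 2, 2]
print(v1.tolist())  # [1, 3, 0, 2]
```

With `is_stable=False`, the relative order of elements with equal keys is not guaranteed, so `v1` could differ from the output shown.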
