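jax.lax.sort_key_val

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)

Sorts keys along dimension and applies the same permutation to values.

Parameters:
keys (Array) – the array to sort; its entries determine the (ascending) order.
values (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – the array reordered with the same permutation as keys.
dimension (int) – the axis along which to sort; defaults to -1, the last axis.
is_stable (bool) – whether to use a stable sort, so that entries with equal keys keep their original relative order; defaults to True.

Return type:
tuple[Array, Array] – the sorted keys and the correspondingly permuted values.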
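A minimal usage sketch, assuming a standard JAX install; the arrays are illustrative only. It shows a 1-D sort, how the default is_stable=True treats duplicate keys, and how the dimension argument selects the axis on a 2-D input. In each call, keys and values have the same shape, since a single permutation is applied to both.

```python
import jax.numpy as jnp
from jax import lax

# 1-D: sort keys ascending and carry the values along with them.
keys = jnp.array([3, 1, 2])
values = jnp.array([30.0, 10.0, 20.0])
sorted_keys, sorted_values = lax.sort_key_val(keys, values)
# sorted_keys   -> [1 2 3]
# sorted_values -> [10. 20. 30.]

# Duplicate keys: with is_stable=True (the default), values that share
# a key keep their original relative order.
dup_keys = jnp.array([2, 1, 2, 1])
dup_vals = jnp.array([0, 1, 2, 3])
k, v = lax.sort_key_val(dup_keys, dup_vals)
# k -> [1 1 2 2]
# v -> [1 3 0 2]

# 2-D: dimension=0 sorts each column independently;
# the default dimension=-1 would sort each row instead.
keys2d = jnp.array([[3, 1],
                    [1, 2]])
vals2d = jnp.array([[0, 1],
                    [2, 3]])
col_k, col_v = lax.sort_key_val(keys2d, vals2d, dimension=0)
# col_k -> [[1 1]
#           [3 2]]
# col_v -> [[2 1]
#           [0 3]]
```

Because the values follow the keys' permutation, this is a convenient way to sort one array by another without computing an explicit index array first.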