
jax.lax.index_take

jax.lax.index_take(src, idxs, axes)
Parameters:
  • src (Array) –

  • idxs (Array) –

  • axes (Sequence[int]) –

Return type:

Array
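The page does not describe the semantics. Based on how the JAX source appears to implement it (an assumption worth checking against your installed version), `index_take` selects, for each position i, the slice of `src` at `idxs[k][i]` along `axes[k]` for every k. It stacks the N selected slices along a new leading axis and keeps the remaining axes in their original order, building a `lax.gather` internally. The sketch below passes `idxs` as a 2-D array with one row per entry in `axes` (assumed layout: shape `(len(axes), N)`) and compares the result with NumPy-style advanced indexing. It also reproduces the result with a hand-built `lax.gather` in `take_via_gather`, a hypothetical helper written for illustration that is not part of JAX. Only in-range indices are used.

```python
import jax.numpy as jnp
from jax import lax


def take_via_gather(src, idxs, axes):
    """Hypothetical helper (not a JAX API): a hand-built gather that is
    assumed to match lax.index_take for in-range indices."""
    # One row of positions per indexed axis -> (N, len(axes)) start indices.
    start = jnp.stack(list(idxs), axis=1)
    # Size-1 window on each indexed axis, full extent on every other axis.
    slice_sizes = tuple(1 if d in axes else n for d, n in enumerate(src.shape))
    dnums = lax.GatherDimensionNumbers(
        # Kept (non-indexed) axes follow the new leading batch axis.
        offset_dims=tuple(range(1, src.ndim - len(axes) + 1)),
        # Drop the size-1 windows of the indexed axes (kept in ascending order).
        collapsed_slice_dims=tuple(sorted(axes)),
        # Column k of `start` indexes operand axis axes[k].
        start_index_map=tuple(axes),
    )
    return lax.gather(src, start, dnums, slice_sizes)


# A 3 x 4 x 2 source with distinct values so each gathered slice is identifiable.
src = jnp.arange(24).reshape(3, 4, 2)
rows = jnp.array([0, 2, 1])  # positions along axis 0
cols = jnp.array([3, 0, 1])  # positions along axis 1
idxs = jnp.stack([rows, cols])  # shape (len(axes), N) = (2, 3)

out = lax.index_take(src, idxs, (0, 1))
print(out.shape)  # (3, 2): N selected slices, each keeping axis 2
print(bool(jnp.array_equal(out, src[rows, cols])))  # True
print(bool(jnp.array_equal(out, take_via_gather(src, idxs, (0, 1)))))  # True
```

The helper sorts `collapsed_slice_dims` because the gather shape check expects them in ascending order, while `start_index_map` keeps the caller's order so that each column of `start` still lines up with its axis in `axes`.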
