jax.lax.slice
- jax.lax.slice(operand, start_indices, limit_indices, strides=None)
Wraps XLA’s Slice operator.
- Parameters:
  - operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array to slice.
  - start_indices (Sequence[int]) – a sequence of operand.ndim start indices.
  - limit_indices (Sequence[int]) – a sequence of operand.ndim limit indices.
  - strides (Optional[Sequence[int]]) – an optional sequence of operand.ndim strides.
- Return type:
  Array
- Returns:
  The sliced array.
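As a rough illustration of how the three sequences relate to the returned array, the sketch below slices a three-dimensional operand and compares the result's shape with a predicted one. The helper `expected_slice_shape` is hypothetical (it is not part of `jax.lax`), and it assumes non-negative, in-bounds indices with `start <= limit` and positive strides.

```python
import jax.numpy as jnp
from jax import lax


def expected_slice_shape(start_indices, limit_indices, strides=None):
    """Hypothetical helper: predict the output shape of lax.slice.

    Assumes each dimension yields ceil((limit - start) / stride) elements.
    """
    if strides is None:
        strides = (1,) * len(start_indices)
    return tuple(
        (limit - start + stride - 1) // stride  # integer ceiling division
        for start, limit, stride in zip(start_indices, limit_indices, strides)
    )


x = jnp.arange(24).reshape(2, 3, 4)  # operand.ndim == 3

# One entry per dimension of the operand in every sequence.
start, limit, strides = (0, 1, 0), (2, 3, 4), (1, 1, 3)
y = lax.slice(x, start, limit, strides)

print(y.shape)                                      # (2, 2, 2)
print(expected_slice_shape(start, limit, strides))  # (2, 2, 2)
```

Each sequence carries exactly one entry per operand dimension, so a dimension that should be kept whole still needs an explicit start of 0 and a limit equal to its size.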
Examples
Here are some examples of simple two-dimensional slices:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.slice(x, (1, 0), (3, 2))
Array([[4, 5],
       [8, 9]], dtype=int32)
>>> lax.slice(x, (0, 0), (3, 4), (1, 2))
Array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]], dtype=int32)
These two examples are equivalent to the following Python slicing syntax:
>>> x[1:3, 0:2]
Array([[4, 5],
       [8, 9]], dtype=int32)
>>> x[0:3, 0:4:2]
Array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]], dtype=int32)
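The correspondence with Python slicing can also be checked programmatically. The following is a minimal sketch, not part of the library: `as_python_slices` is a hypothetical helper that turns the `lax.slice` arguments into a tuple of built-in `slice` objects. It is meant only for non-negative, in-bounds indices, since Python indexing gives negative indices a wrap-around meaning that this sketch does not try to reproduce.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def as_python_slices(start_indices, limit_indices, strides=None):
    """Hypothetical helper: build one slice(start, limit, stride) per dimension."""
    if strides is None:
        strides = (1,) * len(start_indices)
    return tuple(
        slice(start, limit, stride)
        for start, limit, stride in zip(start_indices, limit_indices, strides)
    )


x = jnp.arange(12).reshape(3, 4)

start, limit, strides = (0, 0), (3, 4), (1, 2)
index = as_python_slices(start, limit, strides)  # same as 0:3:1, 0:4:2

via_lax = lax.slice(x, start, limit, strides)
via_indexing = x[index]

# Raises AssertionError if the two results differ.
np.testing.assert_array_equal(via_lax, via_indexing)
```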