jax.lax.reshape
- jax.lax.reshape(operand, new_sizes, dimensions=None)
Wraps XLA’s Reshape operator.
For inserting or removing dimensions of size 1, prefer lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
- Parameters:
operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – array to be reshaped.
new_sizes (Sequence[Union[int, Any]]) – sequence of integers specifying the resulting shape. The total number of elements in the result must match that of the input.
dimensions (Optional[Sequence[int]]) – optional sequence of integers specifying the permutation order of the input shape. If specified, its length must match the length of operand.shape.
- Returns:
reshaped array.
- Return type:
Array
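The note above about size-1 dimensions can be illustrated with a minimal sketch. The variable names are illustrative only; the example compares lax.expand_dims and lax.squeeze with an equivalent lax.reshape call, assuming a small two-dimensional input.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 input (names here are not part of the API).
x = lax.reshape(jnp.arange(6), (2, 3))

# Insert a size-1 axis at position 0: shape (2, 3) -> (1, 2, 3).
expanded = lax.expand_dims(x, (0,))

# Remove that size-1 axis again: shape (1, 2, 3) -> (2, 3).
squeezed = lax.squeeze(expanded, (0,))

# lax.reshape produces the same values, but the call itself does not
# say which axis was inserted.
reshaped = lax.reshape(x, (1, 2, 3))

print(expanded.shape, squeezed.shape, reshaped.shape)
```

Both routes give arrays with the same contents; the difference is only in how explicitly the axis being added or removed is named.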
Examples
Simple reshaping from one to two dimensions:
>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)
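One way to read the dimensions argument is as a transpose applied before the reshape. The sketch below checks that reading on the same 2x3 array, using lax.transpose with the same permutation. It is an illustration of the observed behaviour, not a statement about how the operator is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

y = lax.reshape(jnp.arange(6), (2, 3))

# Reshape with a permutation of the input dimensions.
direct = lax.reshape(y, (6,), dimensions=(1, 0))

# Transpose with the same permutation first, then reshape without one.
via_transpose = lax.reshape(lax.transpose(y, (1, 0)), (6,))

print(direct)
print(via_transpose)
```

Both calls should yield the same element order as the permuted example above.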