jax.numpy.matmul

jax.numpy.matmul(a, b, *, precision=None, preferred_element_type=None)

Matrix product of two arrays.

LAX-backend implementation of numpy.matmul().
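Because this is an implementation of numpy.matmul(), it should follow NumPy's shape rules. The sketch below illustrates those rules, assuming a standard JAX install. The array names are illustrative only. It shows an ordinary 2-D product and a stack of matrices whose leading (batch) dimension broadcasts against a single matrix.

```python
import jax.numpy as jnp

# 2-D @ 2-D: ordinary matrix product, (2, 3) @ (3, 4) -> (2, 4)
a = jnp.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
b = jnp.ones((3, 4))
c = jnp.matmul(a, b)                # each row of a is summed: 3.0 and 12.0

# Stack of matrices: the leading (batch) dimension broadcasts
# against the 2-D operand, (5, 2, 3) @ (3, 4) -> (5, 2, 4)
batch = jnp.ones((5, 2, 3))
d = jnp.matmul(batch, b)

# The @ operator on JAX arrays computes the same matrix product
e = a @ b
```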

In addition to the original NumPy arguments listed below, this function also accepts a precision argument for extra control over matrix-multiplication precision on supported devices. precision may be set to one of the following:

- None, meaning the backend's default precision.
- A single Precision enum value: Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST.
- A tuple of two Precision enum values, giving a separate precision for each argument.
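The three forms of precision can be passed as sketched below. This assumes the enum is reached as jax.lax.Precision; the operands are illustrative. Which devices actually honor each setting is backend-dependent, so on some backends (for example CPU) all three results may be identical.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.float32)
y = jnp.ones((8, 2), dtype=jnp.float32)

# precision=None: use the backend's default precision
default = jnp.matmul(x, y)

# A single Precision value applies to both operands
highest = jnp.matmul(x, y, precision=lax.Precision.HIGHEST)

# A tuple of two Precision values sets the precision per operand
mixed = jnp.matmul(x, y, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))
```

With all-ones inputs every entry is 8.0 regardless of precision, so this only demonstrates the call forms, not numerical differences.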

Original docstring below.

Parameters:
Returns:

y – The matrix product of the inputs. This is a scalar (in JAX, a 0-d array) only when both inputs (x1 and x2 in the NumPy docstring) are 1-D vectors.

Return type:

ndarray
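The 1-D cases mentioned under Returns can be checked with the short sketch below, assuming NumPy's matmul semantics. Two 1-D vectors give their inner product as a 0-d array. A 1-D second operand is treated as a column vector, and that added dimension is removed from the result. The variable names are illustrative.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])

# Two 1-D vectors: inner product, 1*4 + 2*5 + 3*6 = 32, with shape ()
s = jnp.matmul(v, w)

# Matrix @ 1-D vector: v acts as a column, and the result drops
# that dimension, so (2, 3) @ (3,) -> (2,)
m = jnp.ones((2, 3))
mv = jnp.matmul(m, v)   # each row sums v: [6.0, 6.0]
```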