jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

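Batch matrix multiplication. The trailing two dimensions of lhs and rhs are multiplied as matrices; all leading dimensions are treated as batch dimensions.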

Parameters:
  • lhs (Array)
  • rhs (Array)
  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision]])

Return type:

Array
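
The following is a minimal usage sketch, not taken from the original page. It assumes both operands have the same number of dimensions (at least 2), identical leading batch dimensions, and compatible inner dimensions, so that shapes (..., m, k) and (..., k, n) produce (..., m, n). The array shapes and variable names are illustrative only, and the comparison against jnp.einsum is included purely as a sanity check.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs: a batch of 4 matrix pairs, (2, 3) times (3, 5).
lhs = jnp.arange(4 * 2 * 3, dtype=jnp.float32).reshape(4, 2, 3)
rhs = jnp.ones((4, 3, 5), dtype=jnp.float32)

# Multiply each pair of matrices along the leading batch dimension.
out = lax.batch_matmul(lhs, rhs)

# Optionally request the highest available precision for the contraction.
out_hi = lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)

# Sanity check against an equivalent einsum over the same operands.
ref = jnp.einsum("bij,bjk->bik", lhs, rhs)
matches = bool(jnp.allclose(out, ref)) and bool(jnp.allclose(out_hi, ref))
```

If the batch dimensions need to broadcast or the operands differ in rank, jnp.matmul is likely the more convenient choice, since this function expects the ranks to match exactly.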
