jax.extend.random.define_prng_impl

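jax.extend.random.define_prng_impl(*, key_shape, seed, split, random_bits, fold_in, name='<unnamed>', tag='?')
Define a custom pseudo-random number generator (PRNG) implementation from the functions that create keys, split them, fold data into them, and draw random bits from them.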
Parameters:
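  • key_shape (tuple[int, ...]) – shape of the uint32 array that holds the data of a single key, e.g. (2,).

  • seed (Callable[[Array], Array]) – maps an integer seed (a scalar array) to the data of one key, of shape key_shape.

  • split (Callable[[Array, tuple[int, ...]], Array]) – takes one key and a shape, and returns new keys stacked into an array of shape (*shape, *key_shape).

  • random_bits (Callable[[Array, int, tuple[int, ...]], Array]) – takes one key, a bit width, and an output shape, and returns unsigned integers of that width (e.g. uint32 for a width of 32) with that shape.

  • fold_in (Callable[[Array, int], Array]) – takes one key and an integer, and returns a new key derived from both.

  • name (str) – human-readable name of the implementation.

  • tag (str) – short label shown when key arrays of this implementation are printed.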

Return type:

Hashable