jax.random.truncated_normal
- jax.random.truncated_normal(key, lower, upper, shape=None, dtype=<class 'float'>)
Sample truncated standard normal random values with given shape and dtype.
The values are returned according to the probability density function:
\[f(x) \propto e^{-x^2/2}\]
on the domain \(\mathrm{lower} < x < \mathrm{upper}\).
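For reference, the proportionality above can be written as a normalized density. This is the standard truncated-normal identity rather than something stated in the docstring; \(\phi\) and \(\Phi\) denote the standard normal pdf and cdf, symbols introduced here for illustration:
\[f(x) = \frac{\phi(x)}{\Phi(\mathrm{upper}) - \Phi(\mathrm{lower})}, \qquad \phi(x) = \frac{1}{\sqrt{2\pi}}\, e^{-x^2/2}, \qquad \mathrm{lower} < x < \mathrm{upper}.\]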
- Parameters:
key (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a PRNG key used as the random key.
lower (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
upper (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
shape (Union[Sequence[int], NamedShape, None]) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
dtype (Union[str, type[Any], dtype, SupportsDType]) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type:
- Returns:
A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
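A minimal usage sketch of the behavior described above, not taken from the original page. The seed, bound values, and variable names (`samples`, `batched`) are illustrative assumptions. It shows scalar bounds with an explicit shape, and array bounds where shape is left as None so the result shape comes from broadcasting lower and upper.

```python
import jax
import jax.numpy as jnp

# Illustrative seed; split so each draw uses its own key.
key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)

# Scalar bounds with an explicit result shape.
samples = jax.random.truncated_normal(key_a, lower=-2.0, upper=2.0, shape=(1000,))

# Array bounds with shape=None: the result shape is the broadcast of
# lower (shape (3,)) and upper (shape (2, 1)), i.e. (2, 3).
lower = jnp.array([-1.0, 0.0, 1.0])
upper = jnp.array([[2.0], [3.0]])
batched = jax.random.truncated_normal(key_b, lower, upper)

print(samples.shape, samples.dtype)
print(batched.shape)
```

Each element of `batched` is drawn between the corresponding broadcast entries of `lower` and `upper`, so every lower bound must be below its paired upper bound.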