Activation functions#
Activation functions.
- class flax.linen.activation.PReLU(param_dtype=<class 'jax.numpy.float32'>, negative_slope_init=0.01, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Parametric Rectified Linear Unit (PReLU) activation function.
Note that PReLU is a Flax layer and not a simple activation function, so it needs to be initialized before being called.
- Example usage::
x = nn.PReLU()(x)
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- negative_slope_init#
the value to initialize the negative slope (default 0.01).
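Example
An illustrative initialization sketch (not from the original docs; it assumes the conventional import flax.linen as nn alias and a toy input shape):
>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> layer = nn.PReLU()
>>> x = jnp.array([[-1.0, 0.0, 2.0]])
>>> variables = layer.init(jax.random.PRNGKey(0), x)  # creates the learnable negative-slope parameter
>>> y = layer.apply(variables, x)                     # same shape as x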
- flax.linen.activation.celu(x, alpha=1.0)[source]#
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- flax.linen.activation.elu(x, alpha=1.0)[source]#
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
alpha (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – scalar or array of alpha values (default: 1.0)
- Return type:
Array
- Returns:
An array.
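Example
A brief usage sketch (not from the original docs; assumes the conventional import flax.linen as nn alias, with an illustrative alpha):
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.array([-2.0, -0.5, 0.0, 1.0])
>>> y = nn.elu(x, alpha=0.5)  # negative inputs saturate toward -alpha; positive inputs pass through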
See also
- flax.linen.activation.gelu(x, approximate=True)[source]#
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
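Example
An illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias) comparing the two formulations:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.linspace(-3.0, 3.0, 7)
>>> y_approx = nn.gelu(x)                    # default: tanh-based approximation
>>> y_exact = nn.gelu(x, approximate=False)  # erf-based formulation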
- flax.linen.activation.glu(x, axis=-1)[source]#
Gated linear unit activation function.
Computes the function:
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]
where the array is split into two along axis. The size of the axis dimension must be divisible by two.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
axis (int) – the axis along which the split should be computed (default: -1)
- Return type:
Array
- Returns:
An array.
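Example
An illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias) showing how the gated axis is halved:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.ones((2, 6))    # the size of the gated axis must be even
>>> y = nn.glu(x, axis=-1)
>>> y.shape
(2, 3)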
See also
- flax.linen.activation.hard_sigmoid(x)[source]#
Hard Sigmoid activation function.
Computes the element-wise function
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.hard_silu(x)[source]#
Hard SiLU (swish) activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Both hard_silu() and hard_swish() are aliases for the same function.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.hard_swish(x)#
Hard SiLU (swish) activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Both hard_silu() and hard_swish() are aliases for the same function.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.hard_tanh(x)[source]#
Hard \(\mathrm{tanh}\) activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
- flax.linen.activation.leaky_relu(x, negative_slope=0.01)[source]#
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
negative_slope (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – array or scalar specifying the negative slope (default: 0.01)
- Return type:
Array
- Returns:
An array.
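Example
A brief usage sketch (not from the original docs; assumes the conventional import flax.linen as nn alias, with an illustrative slope):
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.array([-4.0, -1.0, 0.0, 2.0])
>>> y = nn.leaky_relu(x, negative_slope=0.1)  # negative inputs are scaled by 0.1 instead of clamped to 0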
See also
- flax.linen.activation.log_sigmoid(x)[source]#
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.log_softmax(x, axis=-1, where=None, initial=None)[source]#
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
axis (Union[int, tuple[int, ...], None]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where (Union[Array, ndarray, bool_, number, bool, int, float, complex, None]) – Elements to include in the log_softmax.
initial (Union[Array, ndarray, bool_, number, bool, int, float, complex, None]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type:
Array
- Returns:
An array.
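Example
An illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias):
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> logits = jnp.array([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]])
>>> log_probs = nn.log_softmax(logits, axis=-1)  # exponentiating each row recovers probabilities summing to 1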
See also
- flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]#
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of scipy.special.logsumexp(). Original docstring below.
- Parameters:
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Return type:
- Returns:
res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
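Example
An illustrative sketch (not from the original docs), using the flax.linen.activation.logsumexp export documented above:
>>> import jax.numpy as jnp
>>> from flax.linen.activation import logsumexp
>>> a = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
>>> row_lse = logsumexp(a, axis=-1)                       # stable log(sum(exp(a))) per row
>>> res, sgn = logsumexp(a, axis=-1, return_sign=True)    # also returns the sign of the summed result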
- flax.linen.activation.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[source]#
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters:
x – a tensor of indices.
num_classes – number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default: float64).
axis – the axis or axes along which the function should be computed (default: -1).
- Return type:
Array
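Example
A further illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias) showing the num_classes and axis arguments:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> labels = jnp.array([0, 2, 1])
>>> targets = nn.one_hot(labels, num_classes=3)  # shape (3, 3), one row per label
>>> by_class = nn.one_hot(labels, 3, axis=0)     # class dimension placed along axis 0 instead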
- flax.linen.activation.relu(x)[source]#
Rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{relu}(x) = \max(x, 0)\]
except under differentiation, we take:
\[\nabla \mathrm{relu}(0) = 0\]
For more information see Numerical influence of ReLU’(0) on backpropagation.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
Example
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
See also
- flax.linen.activation.selu(x)[source]#
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
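Example
A brief usage sketch (not from the original docs; assumes the conventional import flax.linen as nn alias):
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.array([-2.0, -0.5, 0.0, 1.0])
>>> y = nn.selu(x)  # no tunable arguments; lambda and alpha are the fixed constants given above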
See also
- flax.linen.activation.sigmoid(x)[source]#
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.silu(x)[source]#
SiLU (a.k.a. swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
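Example
A brief sketch (not from the original docs; assumes the conventional import flax.linen as nn alias) illustrating the alias:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.array([-2.0, 0.0, 2.0])
>>> y1 = nn.silu(x)
>>> y2 = nn.swish(x)  # identical result: swish is an alias of silu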
See also
- flax.linen.activation.soft_sign(x)[source]#
Soft-sign activation function.
Computes the element-wise function
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
- flax.linen.activation.softmax(x, axis=-1, where=None, initial=None)[source]#
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
axis (Union[int, tuple[int, ...], None]) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where (Union[Array, ndarray, bool_, number, bool, int, float, complex, None]) – Elements to include in the softmax.
initial (Union[Array, ndarray, bool_, number, bool, int, float, complex, None]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type:
Array
- Returns:
An array.
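Example
An illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias) showing a single axis and a tuple of axes:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> logits = jnp.array([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]])
>>> probs = nn.softmax(logits, axis=-1)      # each row sums to 1
>>> joint = nn.softmax(logits, axis=(0, 1))  # normalizes jointly over both axes; all entries sum to 1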
See also
- flax.linen.activation.softplus(x)[source]#
Softplus activation function.
Computes the element-wise function
\[\mathrm{softplus}(x) = \log(1 + e^x)\]
- flax.linen.activation.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#
Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
mean (Union[Array, ndarray, bool_, number, bool, int, float, complex, None])
variance (Union[Array, ndarray, bool_, number, bool, int, float, complex, None])
epsilon (Union[Array, ndarray, bool_, number, bool, int, float, complex])
where (Union[Array, ndarray, bool_, number, bool, int, float, complex, None])
- Return type:
Array
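Example
An illustrative sketch (not from the original docs; assumes the conventional import flax.linen as nn alias); the precomputed statistics in the second call are hypothetical values for demonstration:
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> x = jnp.array([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])
>>> z = nn.standardize(x, axis=-1)  # per-row zero mean and (approximately) unit variance
>>> z_fixed = nn.standardize(x, axis=-1, mean=jnp.zeros((2, 1)), variance=jnp.ones((2, 1)))  # use supplied statistics instead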
- flax.linen.activation.swish(x)#
SiLU (a.k.a. swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array
- Return type:
Array
- Returns:
An array.
See also
- flax.linen.activation.tanh(x, /)#
Compute hyperbolic tangent element-wise.
LAX-backend implementation of numpy.tanh(). Original docstring below.
Equivalent to np.sinh(x)/np.cosh(x) or -1j * np.tan(1j*x).
- Parameters:
x (array_like) – Input array.
- Returns:
y – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar.
- Return type:
ndarray
Summary

| PReLU | Parametric Rectified Linear Unit (PReLU) activation function. |
| celu | Continuously-differentiable exponential linear unit activation. |
| elu | Exponential linear unit activation function. |
| gelu | Gaussian error linear unit activation function. |
| glu | Gated linear unit activation function. |
| hard_sigmoid | Hard Sigmoid activation function. |
| hard_silu | Hard SiLU (swish) activation function. |
| hard_swish | Hard SiLU (swish) activation function. |
| hard_tanh | Hard \(\mathrm{tanh}\) activation function. |
| leaky_relu | Leaky rectified linear unit activation function. |
| log_sigmoid | Log-sigmoid activation function. |
| log_softmax | Log-Softmax function. |
| logsumexp | Compute the log of the sum of exponentials of input elements. |
| one_hot | One-hot encodes the given indices. |
| relu | Rectified linear unit activation function. |
| relu6 | Rectified Linear Unit 6 activation function. |
| selu | Scaled exponential linear unit activation. |
| sigmoid | Sigmoid activation function. |
| silu | SiLU (a.k.a. swish) activation function. |
| soft_sign | Soft-sign activation function. |
| softmax | Softmax function. |
| softplus | Softplus activation function. |
| standardize | Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\). |
| swish | SiLU (a.k.a. swish) activation function. |
| tanh | Compute hyperbolic tangent element-wise. |