flax.linen.LSTMCell

class flax.linen.LSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function jax.numpy.tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

LSTM cell.

The mathematical definition of the cell is as follows:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.
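As a rough illustration of the equations above, the following sketch computes one LSTM step in plain JAX. It is not the Flax implementation: the helper names `lstm_step` and `init_params`, the dictionary layout of the parameters, the `(c, h)` carry ordering, the scaled-normal initialization, and the row-vector convention (`x @ W` instead of `W x`) are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_features, features):
    """Hypothetical helper: one (W_i*, W_h*, b_h*) triple per gate i, f, g, o."""
    params = {}
    for name, k in zip("ifgo", jax.random.split(key, 4)):
        k_in, k_rec = jax.random.split(k)
        params[name] = (
            0.1 * jax.random.normal(k_in, (in_features, features)),  # input kernel
            0.1 * jax.random.normal(k_rec, (features, features)),    # recurrent kernel
            jnp.zeros((features,)),                                  # bias
        )
    return params


def lstm_step(params, carry, x):
    """One step of the cell equations; carry is assumed to be (c, h)."""
    c, h = carry

    def gate(name, act):
        w_in, w_rec, b = params[name]
        return act(x @ w_in + h @ w_rec + b)

    i = gate("i", jax.nn.sigmoid)  # input gate
    f = gate("f", jax.nn.sigmoid)  # forget gate
    g = gate("g", jnp.tanh)        # candidate memory
    o = gate("o", jax.nn.sigmoid)  # output gate
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return (new_c, new_h), new_h


params = init_params(jax.random.PRNGKey(0), in_features=3, features=4)
x = jnp.ones((2, 3))                             # batch of 2 inputs
carry = (jnp.zeros((2, 4)), jnp.zeros((2, 4)))   # initial memory c and output h
(new_c, new_h), y = lstm_step(params, carry, x)
```

The step returns both the updated carry and the output so that it can be used as the body of a scan over time steps; here the output `y` is simply the new hidden state `h'`.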

features

number of output features.

gate_fn

activation function used for gates (default: sigmoid).

activation_fn

activation function used for output and memory update (default: tanh).

kernel_init

initializer function for the kernels that transform the input (default: lecun_normal).

recurrent_kernel_init

initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

bias_init

initializer for the bias parameters (default: initializers.zeros_init()).

dtype

the dtype of the computation (default: infer from inputs and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

Parameters: