flax.linen.OptimizedLSTMCell
- class flax.linen.OptimizedLSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function jax.numpy.tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)
A more efficient LSTM cell that concatenates state components before the matmul.
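The speedup comes from doing fewer, larger matrix multiplications. A minimal sketch of that general technique in plain JAX (sizes and variable names are illustrative assumptions, and this is not Flax's internal code): stacking the four per-gate input kernels along the output axis lets a single matmul produce all four gate pre-activations, which are then split apart again.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions, not Flax defaults).
batch, in_features, hidden = 2, 3, 4
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (batch, in_features))

# One input kernel per gate (i, f, g, o), as in the per-gate formulation.
w_ii, w_if, w_ig, w_io = jax.random.normal(k_w, (4, in_features, hidden))

# Separate path: four small matmuls, one per gate.
separate = [x @ w for w in (w_ii, w_if, w_ig, w_io)]

# Concatenated path: stack the kernels along the output axis, do one matmul,
# then split the result back into the four gate pre-activations.
w_cat = jnp.concatenate([w_ii, w_if, w_ig, w_io], axis=-1)  # (in_features, 4 * hidden)
fused = jnp.split(x @ w_cat, 4, axis=-1)

# Both paths agree up to float32 rounding.
max_diff = max(float(jnp.max(jnp.abs(a - b))) for a, b in zip(separate, fused))
print(f"max difference: {max_diff:.2e}")
```

The same trick applies to the recurrent kernels that multiply h; one larger matmul generally uses the hardware better than several small ones.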
The parameters are compatible with LSTMCell. This cell is often faster than LSTMCell when the hidden size is roughly 2048 units or fewer.
The mathematical definition of the cell is the same as that of LSTMCell:
\[
\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c')
\end{array}
\]
where x is the input, h is the output of the previous time step, c is the memory, and * denotes element-wise multiplication.
- gate_fn
activation function used for gates (default: sigmoid).
- activation_fn
activation function used for output and memory update (default: tanh).
- kernel_init
initializer function for the kernels that transform the input (default: lecun_normal).
- recurrent_kernel_init
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- bias_init
initializer for the bias parameters (default: initializers.zeros_init()).
- dtype
the dtype of the computation (default: infer from inputs and params).
- param_dtype
the dtype passed to parameter initializers (default: float32).