flax.linen.Embed

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Embedding Module.

A parameterized function that maps integers in [0, n) to d-dimensional vectors, where n is num_embeddings and d is features.

num_embeddings

number of embeddings, i.e. the number of rows in the embedding table (valid inputs are 0 to num_embeddings - 1).

features

number of feature dimensions for each embedding.

dtype

the dtype of the embedding vectors returned by the lookup (default: None, meaning the dtype of the stored embedding table is used).

param_dtype

the dtype passed to parameter initializers (default: float32).

embedding_init

initializer for the embedding table, which has shape (num_embeddings, features).

Parameters: