flax.linen.Embed
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)
Embedding Module.
A parameterized function from integers [0, n) to d-dimensional vectors, where n is num_embeddings and d is features.
- num_embeddings
number of embeddings, i.e. the number of distinct integer ids that can be looked up.
- features
number of feature dimensions for each embedding vector.
- dtype
the dtype of the embedding vectors returned by the lookup (default: same as the embedding parameter).
- param_dtype
the dtype passed to parameter initializers (default: float32).
- embedding_init
initializer for the embedding table.