flax.linen.ConvTranspose

class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)

Transposed convolution Module wrapping jax.lax.conv_transpose.

features

number of convolution filters.

kernel_size

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.

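strides

a sequence of n integers, one per spatial dimension, representing the inter-window strides. In a transposed convolution the strides act as upsampling factors. If None (the default), a stride of 1 is used in every spatial dimension.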

padding

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs giving the padding to apply before and after each spatial dimension. A single int applies the same padding to both sides of every spatial dimension, and a single int inside the sequence applies that padding to both sides of the corresponding dimension.

kernel_dilation

None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.

use_bias

whether to add a bias to the output (default: True).

mask

Optional mask for the weights during masked convolution. The mask must have the same shape as the convolution kernel.

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

precision

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

initializer for the convolutional kernel.

bias_init

initializer for the bias.

transpose_kernel

if True, flips the spatial axes and swaps the input/output channel axes of the kernel.

Parameters: