flax.linen.ConvLocal
- class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
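Local convolution Module wrapping lax.conv_general_dilated_local. Unlike Conv, a local convolution does not share kernel weights across spatial positions: every output position has its own filter.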
- features
number of convolution filters.
- kernel_size
shape of the convolutional kernel.
- strides
an integer or a sequence of n integers (one per spatial dimension), representing the inter-window strides (default: 1).
- padding
either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), the string ‘CAUSAL’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int applies the same padding on both sides of every spatial dimension, and a single int inside the sequence applies the same padding on both sides of that dimension. ‘CAUSAL’ padding for a 1D convolution left-pads the convolution axis, resulting in a same-sized output.
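The padding forms above determine the output length along each spatial dimension. The helper below is hypothetical (not part of Flax) and sketches the usual arithmetic for one dimension, ignoring input_dilation. It assumes ‘SAME’ and ‘CIRCULAR’ both yield ceil(n / stride) outputs and that ‘CAUSAL’ pads (effective kernel span − 1) elements on the left only.

```python
import math
from typing import Tuple, Union

def conv_output_length(
    n: int,
    k: int,
    stride: int = 1,
    padding: Union[str, int, Tuple[int, int]] = "SAME",
    dilation: int = 1,
) -> int:
    """Hypothetical helper: output length along ONE spatial dimension."""
    k_eff = (k - 1) * dilation + 1  # span covered by a dilated kernel
    if padding in ("SAME", "CIRCULAR"):
        # Padding is chosen so that the output length is ceil(n / stride).
        return math.ceil(n / stride)
    if padding == "VALID":
        lo, hi = 0, 0  # no padding at all
    elif padding == "CAUSAL":
        lo, hi = k_eff - 1, 0  # left-pad only
    elif isinstance(padding, int):
        lo, hi = padding, padding  # same amount on both sides
    else:
        lo, hi = padding  # explicit (low, high) pair
    return max(0, (n + lo + hi - k_eff) // stride + 1)

print(conv_output_length(8, 3, padding="VALID"))   # 6
print(conv_output_length(8, 3, padding="CAUSAL"))  # 8
print(conv_output_length(8, 3, padding=(1, 1)))    # 8
```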
- input_dilation
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
- kernel_dilation
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.
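Kernel (atrous) dilation spreads the kernel taps apart instead of the inputs, so a kernel of size k with dilation d spans (k − 1)·d + 1 input elements. As with input dilation, this sketch uses jax.lax.conv_general_dilated and assumes ConvLocal passes kernel_dilation through as the rhs_dilation argument of lax.conv_general_dilated_local.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[[1.0, 2.0, 3.0, 4.0, 5.0]]])  # NCH: (1, 1, 5)
w = jnp.array([[[10.0, 1.0]]])                 # OIH: (1, 1, 2)

# Kernel dilation 2 places the two taps one element apart,
# giving an effective span of (2 - 1) * 2 + 1 = 3.
atrous = lax.conv_general_dilated(
    x, w, window_strides=(1,), padding="VALID", rhs_dilation=(2,)
)

# Equivalent undilated kernel with a zero inserted between the taps.
w_stuffed = jnp.array([[[10.0, 0.0, 1.0]]])
explicit = lax.conv_general_dilated(x, w_stuffed, window_strides=(1,), padding="VALID")

print(atrous[0, 0])  # values 13, 24, 35
```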
- feature_group_count
integer, default 1. If greater than 1, divides the input features into that many groups.
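Grouping means each group of output features only sees its own slice of the input features. The sketch shows this with jax.lax.conv_general_dilated, whose feature_group_count argument implements the grouping for shared-weight convolutions. As far as we know, jax.lax.conv_general_dilated_local has no such argument, so depending on the Flax version ConvLocal may accept only the default value of 1; treat this as an assumption.

```python
import jax
import jax.numpy as jnp
from jax import lax

groups = 2
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 4, 6))  # NCH: 4 input channels, width 6
# OIH: 2 output channels, each seeing 4 / groups = 2 input channels.
w = jax.random.normal(key_w, (2, 2, 3))

grouped = lax.conv_general_dilated(
    x, w, window_strides=(1,), padding="VALID", feature_group_count=groups
)

# Same result: convolve each half of the input channels with its own filter,
# then stack the per-group outputs along the channel axis.
per_group = jnp.concatenate(
    [
        lax.conv_general_dilated(x[:, 0:2], w[0:1], window_strides=(1,), padding="VALID"),
        lax.conv_general_dilated(x[:, 2:4], w[1:2], window_strides=(1,), padding="VALID"),
    ],
    axis=1,
)
```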
- use_bias
whether to add a bias to the output (default: True).
- mask
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- dtype
the dtype of the computation (default: infer from input and params).
- param_dtype
the dtype passed to parameter initializers (default: float32).
- precision
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init
initializer for the convolutional kernel.
- bias_init
initializer for the bias.
- Parameters:
features (int)
padding (Union[str, int, Sequence[Union[int, Tuple[int, int]]]])
feature_group_count (int)
use_bias (bool)
param_dtype (Any)
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]])
conv_general_dilated_cls (Any)
parent (Union[Type[Module], Scope, Type[_Sentinel], None])