# Source code for kornia.augmentation._2d.base

from typing import Tuple, Union

from torch import Tensor, float16, float32, float64

import kornia
from kornia.augmentation.base import _AugmentationBase
from kornia.augmentation.utils import _transform_input, _validate_input_dtype


class AugmentationBase2D(_AugmentationBase):
    r"""AugmentationBase2D base class for customized augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are required while
    the "compute_transformation" is only required when passing "return_transform" as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation probabilities
          element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
          probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
          to the batch form ``False``.
    """

    def __check_batching__(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> None:
        """Check that an ``(image, transformation_matrix)`` pair agrees on batching.

        A plain (non-tuple) input is accepted without any check; only tuples are validated.

        Args:
            input: an image tensor, or a 2-tuple ``(inp, mat)`` of an image tensor and its
                transformation matrix.

        Raises:
            AssertionError: if ``inp`` is 4-D but ``mat`` is not 3-D, if their leading (batch)
                sizes differ, or if ``inp`` is 2-D/3-D but ``mat`` is not 2-D.
            ValueError: if ``inp`` is not 2-, 3- or 4-dimensional.
        """
        if not isinstance(input, tuple):
            return
        inp, mat = input
        ndim = len(inp.shape)
        if ndim == 4:
            # Batched image: one matrix per sample is expected, i.e. a 3-D stack of matrices
            # whose leading size matches the image batch size.
            if len(mat.shape) != 3:
                raise AssertionError('Input tensor is in batch mode ' 'but transformation matrix is not')
            if mat.shape[0] != inp.shape[0]:
                raise AssertionError(
                    f"In batch dimension, input has {inp.shape[0]} but transformation matrix has {mat.shape[0]}"
                )
        elif ndim in (2, 3):
            # Unbatched image: a single 2-D matrix is expected.
            if len(mat.shape) != 2:
                raise AssertionError("Input tensor is in non-batch mode but transformation matrix is not")
        else:
            # The rank being validated here is the *input* image's, not an output's.
            raise ValueError(f'Unrecognized input shape. Expected 2, 3, or 4, got {ndim}')

    def transform_tensor(self, input: Tensor) -> Tensor:
        """Convert any incoming (H, W), (C, H, W) and (B, C, H, W) into (B, C, H, W).

        The dtype is first checked by ``_validate_input_dtype`` against float16, float32 and
        float64; the reshaping itself is delegated to ``_transform_input``.

        Args:
            input: the image tensor to normalize to batch form.

        Returns:
            The tensor returned by ``_transform_input``.
        """
        _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
        return _transform_input(input)

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return a 3x3 identity matrix built by ``kornia.eye_like`` from ``input``.

        NOTE(review): ``kornia.eye_like(3, input)`` presumably yields a batched ``(B, 3, 3)``
        identity matching ``input``'s batch size, device and dtype — confirm against the kornia
        version in use.

        Args:
            input: reference tensor passed to ``kornia.eye_like``.
        """
        return kornia.eye_like(3, input)