from typing import Tuple, Union

from torch import Tensor, float16, float32, float64

import kornia
from kornia.augmentation.base import _AugmentationBase
from kornia.augmentation.utils import _transform_input, _validate_input_dtype
class AugmentationBase2D(_AugmentationBase):
    r"""AugmentationBase2D base class for customized augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are required while the
    "compute_transformation" is only required when passing "return_transform" as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation probabilities
          element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
          probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input (``True``) or broadcast it to the batch
          form (``False``).
    """

    def __check_batching__(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> None:
        """Check that a ``(tensor, transformation matrix)`` pair agrees on batching.

        Only tuple inputs are inspected; any other input is accepted without checks.

        Args:
            input: either a single tensor, or a 2-tuple of ``(tensor, transformation matrix)``.

        Raises:
            AssertionError: if the tensor is 4-dimensional but the matrix is not 3-dimensional, if their
                leading (batch) sizes differ, or if the tensor is 2- or 3-dimensional but the matrix is
                not 2-dimensional.
            ValueError: if the tensor is not 2-, 3- or 4-dimensional.
        """
        if not isinstance(input, tuple):
            # A bare input carries no transformation matrix, so there is nothing to cross-check.
            return

        inp, mat = input
        inp_ndim = len(inp.shape)
        mat_ndim = len(mat.shape)

        if inp_ndim == 4:
            # Batched input: the matrix needs a leading dimension of the same size.
            if mat_ndim != 3:
                raise AssertionError("Input tensor is in batch mode but transformation matrix is not")
            if mat.shape[0] != inp.shape[0]:
                raise AssertionError(
                    f"In batch dimension, input has {inp.shape[0]} but transformation matrix has {mat.shape[0]}"
                )
        elif inp_ndim in (2, 3):
            # Non-batched input: the matrix must be a single 2-dimensional matrix.
            if mat_ndim != 2:
                raise AssertionError("Input tensor is in non-batch mode but transformation matrix is not")
        else:
            # The value being validated is the input tensor's rank, so the message says "input".
            raise ValueError(f"Unrecognized input shape. Expected 2, 3, or 4 dimensions, got {inp_ndim}")

    def transform_tensor(self, input: Tensor) -> Tensor:
        """Convert any incoming (H, W), (C, H, W) and (B, C, H, W) into (B, C, H, W).

        The dtype is first checked against float16, float32 and float64 via ``_validate_input_dtype``;
        the reshaping itself is delegated to ``_transform_input``.

        Args:
            input: the tensor to validate and reshape.

        Returns:
            The result of ``_transform_input(input)``.
        """
        _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
        return _transform_input(input)

    def identity_matrix(self, input) -> Tensor:
        """Return 3x3 identity matrix.

        Args:
            input: reference value forwarded to ``kornia.eye_like``.
                NOTE(review): presumably the batch size, dtype and device are taken from it - confirm
                against ``kornia.eye_like``.

        Returns:
            The result of ``kornia.eye_like(3, input)``.
        """
        return kornia.eye_like(3, input)