"""Source code for kornia.augmentation.container.dispatcher: dispatchers routing inputs through ``AugmentationSequential`` pipelines."""

from typing import List, Tuple, Union

import torch.nn as nn
from torch import Tensor

from kornia.augmentation.container.augment import AugmentationSequential


class ManyToManyAugmentationDispather(nn.Module):
    r"""Dispatches different augmentations to different inputs element-wisely.

    The ``i``-th positional input given to :meth:`forward` is processed by the ``i``-th
    augmentation pipeline, so the number of inputs must equal the number of pipelines.

    Args:
        augmentations: a variable number of kornia ``AugmentationSequential`` modules.

    Raises:
        ValueError: if any of ``augmentations`` is not an ``AugmentationSequential``.

    Examples:
        >>> import torch
        >>> input_1, input_2 = torch.randn(2, 3, 5, 6), torch.randn(2, 3, 5, 6)
        >>> mask_1, mask_2 = torch.ones(2, 3, 5, 6), torch.ones(2, 3, 5, 6)
        >>> aug_list = ManyToManyAugmentationDispather(
        ...     AugmentationSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...         data_keys=["input", "mask",],
        ...     ),
        ...     AugmentationSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...         data_keys=["input", "mask",],
        ...     )
        ... )
        >>> output = aug_list((input_1, mask_1), (input_2, mask_2))
    """

    def __init__(self, *augmentations: AugmentationSequential) -> None:
        super().__init__()
        self._check_consistency(*augmentations)
        # A ModuleList (instead of a plain tuple) registers every pipeline as a submodule,
        # so .to(), .train()/.eval(), state_dict() and repr() propagate to them.
        self.augmentations = nn.ModuleList(augmentations)

    def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
        """Validate that every augmentation is an ``AugmentationSequential``.

        Returns:
            ``True`` when all augmentations are valid.

        Raises:
            ValueError: naming the index of the first invalid augmentation.
        """
        for i, aug in enumerate(augmentations):
            if not isinstance(aug, AugmentationSequential):
                raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequentials`.")
        return True

    def forward(
        self, *inputs: Union[Tensor, Tuple[Tensor, ...], List[Tensor]]
    ) -> Union[List[Tensor], List[Tuple[Tensor]]]:
        """Apply the ``i``-th pipeline to the ``i``-th input.

        Args:
            inputs: one entry per pipeline. A tuple/list entry is unpacked into positional
                arguments (matching that pipeline's ``data_keys``); a single ``Tensor`` entry
                is passed through as one argument.

        Returns:
            A list with the output of each pipeline, in order.

        Raises:
            ValueError: if the number of inputs differs from the number of pipelines.
        """
        if len(inputs) != len(self.augmentations):
            # zip() would otherwise silently drop the surplus inputs or pipelines.
            raise ValueError(
                f"Expected {len(self.augmentations)} inputs (one per augmentation), got {len(inputs)}."
            )
        outputs = []
        for inp, aug in zip(inputs, self.augmentations):
            # Unpacking a bare Tensor with `*` would split it along the batch dimension.
            outputs.append(aug(inp) if isinstance(inp, Tensor) else aug(*inp))
        return outputs
class ManyToOneAugmentationDispather(nn.Module):
    r"""Dispatches different augmentations to a single input and returns a list.

    Every pipeline receives the same positional inputs, and the output of each pipeline
    is collected into a list (one entry per pipeline, in order).

    The same ``data_keys`` must be used across the different augmentations. By default, with
    input ``(image, mask)``, a pipeline must not declare them as ``(mask, image)`` to avoid
    unexpected errors. This check can be disabled with ``strict=False`` if needed.

    Args:
        augmentations: a variable number of kornia ``AugmentationSequential`` modules.
        strict: if ``True``, require all pipelines to share identical ``data_keys``.

    Raises:
        ValueError: if any of ``augmentations`` is not an ``AugmentationSequential``.
        RuntimeError: if ``strict`` is ``True`` and the pipelines do not all share the same
            ``data_keys``.

    Examples:
        >>> import torch
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> aug_list = ManyToOneAugmentationDispather(
        ...     AugmentationSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...         data_keys=["input", "mask",],
        ...     ),
        ...     AugmentationSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...         data_keys=["input", "mask",],
        ...     )
        ... )
        >>> output = aug_list(input, mask)
    """

    def __init__(self, *augmentations: AugmentationSequential, strict: bool = True) -> None:
        super().__init__()
        # Must be set before the consistency check, which reads it.
        self.strict = strict
        self._check_consistency(*augmentations)
        # A ModuleList (instead of a plain tuple) registers every pipeline as a submodule,
        # so .to(), .train()/.eval(), state_dict() and repr() propagate to them.
        self.augmentations = nn.ModuleList(augmentations)

    def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
        """Validate the pipeline types and, when ``self.strict``, their ``data_keys``.

        Comparing each pipeline with its predecessor is enough to ensure that all of them
        share the same ``data_keys``, because equality is transitive.

        Returns:
            ``True`` when all checks pass.

        Raises:
            ValueError: if an augmentation is not an ``AugmentationSequential``.
            RuntimeError: if ``self.strict`` and two adjacent pipelines differ in ``data_keys``.
        """
        for i, aug in enumerate(augmentations):
            if not isinstance(aug, AugmentationSequential):
                raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequentials`.")
            if self.strict and i != 0 and aug.data_keys != augmentations[i - 1].data_keys:
                # Report the keys in the same (i - 1, i) order as the indices in the message.
                raise RuntimeError(
                    f"Different `data_keys` between {i - 1} and {i} elements, "
                    f"got {augmentations[i - 1].data_keys} and {aug.data_keys}."
                )
        return True

    def forward(self, *inputs: Union[Tensor, Tuple[Tensor]]) -> Union[List[Tensor], List[Tuple[Tensor]]]:
        """Apply every pipeline to the same inputs.

        Args:
            inputs: positional inputs forwarded unchanged to each pipeline, in the order
                expected by its ``data_keys``.

        Returns:
            A list with the output of each pipeline, in the order the pipelines were given.
        """
        return [aug(*inputs) for aug in self.augmentations]