from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.mix.base import MixAugmentationBaseV2
from kornia.constants import DataKey
from kornia.core import Tensor
__all__ = ["RandomJigsaw"]
class RandomJigsaw(MixAugmentationBaseV2):
    r"""RandomJigsaw augmentation.

    .. image:: https://raw.githubusercontent.com/kornia/data/main/random_jigsaw.png

    Make Jigsaw puzzles for each image individually. To mix patches coming from different
    images in a batch, refer to :class:`kornia.augmentation.RandomMosaic`.

    Args:
        grid: the Jigsaw puzzle grid as ``(rows, columns)``. e.g. (2, 2) means
            each output will mix image patches in a 2x2 grid. When the image size is not
            divisible by the grid, only the largest divisible top-left region is shuffled
            and the remaining border rows/columns are left unchanged.
        ensure_perm: to ensure the nonidentical patch permutation generation against
            the original one.
        data_keys: the input type sequential for applying augmentations.
            Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
            Defaults to ``[DataKey.INPUT]`` when ``None``.
        p: probability of applying the transformation (forwarded to the base class as ``p``,
            with ``p_batch`` fixed to 1.0).
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
            to the batch form ``False``.

    Examples:
        >>> jigsaw = RandomJigsaw((4, 4))
        >>> input = torch.randn(8, 3, 256, 256)
        >>> out = jigsaw(input)
        >>> out.shape
        torch.Size([8, 3, 256, 256])
    """

    def __init__(
        self,
        grid: Tuple[int, int] = (4, 4),
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
        p: float = 0.5,
        same_on_batch: bool = False,
        keepdim: bool = False,
        ensure_perm: bool = True,
    ) -> None:
        # A fresh list per instance: a list literal as the default value would be shared
        # by every instance constructed without an explicit ``data_keys``.
        if data_keys is None:
            data_keys = [DataKey.INPUT]
        super().__init__(p=p, p_batch=1.0, same_on_batch=same_on_batch, keepdim=keepdim, data_keys=data_keys)
        self._param_generator = rg.JigsawGenerator(grid, ensure_perm)
        self.flags = {"grid": grid}

    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], maybe_flags: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Shuffle the grid pieces of every selected image.

        Args:
            input: image batch; it is unpacked as ``(B, C, H, W)`` after selection.
            params: must provide ``"batch_prob"`` (used to index the batch dimension of
                ``input``) and ``"permutation"`` (per-image piece indices; presumably one row
                per selected image with ``grid[0] * grid[1]`` entries -- confirm against
                ``JigsawGenerator``).
            maybe_flags: unused; accepted for interface compatibility.

        Returns:
            The shuffled images for the selected samples only, with the same ``(C, H, W)``
            as the input.
        """
        # Different from the base class routine: only the samples selected by ``batch_prob``
        # are processed and returned; non-transformed images are not referred to here.
        to_apply = params["batch_prob"]
        selected = input[to_apply].clone()

        b, c, h, w = selected.shape
        grid_h, grid_w = self.flags["grid"]
        piece_h, piece_w = h // grid_h, w // grid_w
        # Largest region that the grid tiles exactly (e.g. 99x99 for a 100x100 image, 3x3 grid).
        eff_h, eff_w = piece_h * grid_h, piece_w * grid_w
        num_pieces = grid_h * grid_w

        # Cut into non-overlapping pieces and flatten to (C, B * N, H', W').
        # Each axis must use its own piece size as both window size and step, otherwise
        # non-square pieces are extracted incorrectly. Explicit sizes (instead of -1) keep
        # the reshapes well-defined when no sample is selected.
        pieces = (
            selected[..., :eff_h, :eff_w]
            .unfold(2, piece_h, piece_h)
            .unfold(3, piece_w, piece_w)
            .reshape(b, c, num_pieces, piece_h, piece_w)
            .permute(1, 0, 2, 3, 4)
            .reshape(c, b * num_pieces, piece_h, piece_w)
        )

        # Offset each image's permutation so that it indexes into the flattened piece axis.
        perm = params["permutation"]
        flat_perm = (perm + torch.arange(0, b, device=perm.device)[:, None] * perm.shape[1]).view(-1)
        pieces = pieces[:, flat_perm, :, :]

        # Stitch the pieces back together. Pieces are laid out column by column:
        # piece k of an image lands in grid column k // grid_h, grid row k % grid_h.
        shuffled = (
            pieces.reshape(c, b, grid_w, eff_h, piece_w)
            .permute(1, 0, 3, 2, 4)
            .reshape(b, c, eff_h, eff_w)
        )

        if (eff_h, eff_w) == (h, w):
            return shuffled

        # Image size not divisible by the grid: keep the leftover border from the input.
        selected[..., :eff_h, :eff_w] = shuffled
        return selected