from typing import Optional, Tuple, Union, cast
import torch
from torch.nn.modules.utils import _pair
from kornia.core import Module, Tensor, concatenate, pad
PadType = Union[Tuple[int, int], Tuple[int, int, int, int]]
def compute_padding(
original_size: Union[int, Tuple[int, int]], window_size: Union[int, Tuple[int, int]]
) -> Tuple[int, int, int, int]:
r"""Compute required padding to ensure chaining of :func:`extract_tensor_patches` and
:func:`combine_tensor_patches` produces expected result.
Args:
original_size: the size of the original tensor.
window_size: the size of the sliding window used while extracting patches.
Return:
        The required padding as a tuple of four ints: ``(top, bottom, left, right)``.
Example:
>>> image = torch.arange(12).view(1, 1, 4, 3)
>>> padding = compute_padding((4,3), (3,3))
>>> out = extract_tensor_patches(image, window_size=(3, 3), stride=(3, 3), padding=padding)
>>> combine_tensor_patches(out, original_size=(4, 3), window_size=(3, 3), stride=(3, 3), unpadding=padding)
        tensor([[[[ 0,  1,  2],
                  [ 3,  4,  5],
                  [ 6,  7,  8],
                  [ 9, 10, 11]]]])
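
    When a dimension does not divide evenly by the window size, the odd
    leftover is placed entirely on the top (or left) side, as the nested
    ``paddim`` helper below shows:

    >>> compute_padding((5, 6), (3, 3))
    (1, 0, 0, 0)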
.. note::
This function is supposed to be used in conjunction with :func:`extract_tensor_patches`
and :func:`combine_tensor_patches`.
"""
original_size = cast(Tuple[int, int], _pair(original_size))
window_size = cast(Tuple[int, int], _pair(window_size))
    def paddim(dim1: int, dim2: int) -> Tuple[int, int]:
        # Amount needed to round ``dim1`` up to a multiple of ``dim2``; it is
        # split evenly between the two sides, with an odd remainder placed
        # entirely on the first (top/left) side.
        if dim1 % dim2 == 0:
            p1 = 0
            p2 = 0
        else:
            tmp = dim2 - (dim1 % dim2)
            if tmp % 2 == 0:
                p1 = p2 = tmp // 2
            else:
                p1 = tmp
                p2 = 0
        return p1, p2
padt, padb = paddim(original_size[0], window_size[0])
padl, padr = paddim(original_size[1], window_size[1])
return (padt, padb, padl, padr)
class CombineTensorPatches(Module):
r"""Module that combine patches from tensors.
    In the simplest case, the output value of the operator with input size
    :math:`(B, N, C, H_{out}, W_{out})` is :math:`(B, C, H, W)`, where

      - :math:`B` is the batch size.
      - :math:`N` denotes the total number of extracted patches stacked in
        left-right and top-bottom order.
      - :math:`C` denotes the number of input channels.
      - :math:`H`, :math:`W` denote the height and width of the input in pixels.
      - :math:`H_{out}`, :math:`W_{out}` denote the patch size defined in
        the function signature.
    * :attr:`original_size` is the size of the original image prior to
      extracting tensor patches; it defines the spatial shape of the output.
    * :attr:`window_size` is the size of the sliding window used while
      extracting tensor patches.
    * :attr:`unpadding` is the amount of padding to be removed. This value
      must be the same as the padding used while extracting tensor patches.
The parameters :attr:`original_size`, :attr:`window_size`, and :attr:`unpadding` can
be either:
- a single ``int`` -- in which case the same value is used for the
height and width dimension.
- a ``tuple`` of two ints -- in which case, the first `int` is used for
the height dimension, and the second `int` for the width dimension.
:attr:`unpadding` can also be a ``tuple`` of four ints -- in which case, the
first two ints are for the height dimension while the last two ints are for
the width dimension.
Args:
patches: patched tensor.
        original_size: the size of the original tensor; this is also the size of the output.
        window_size: the size of the sliding window used.
        unpadding: remove the padding added to both sides of the input.
Shape:
- Input: :math:`(B, N, C, H_{out}, W_{out})`
- Output: :math:`(B, C, H, W)`
Example:
>>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
>>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])
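
        The module form wraps the functional API with ``stride`` fixed to
        ``window_size``:

        >>> patches = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
        >>> CombineTensorPatches(original_size=(4, 4), window_size=(2, 2))(patches)
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])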
.. note::
        This module is supposed to be used in conjunction with :class:`ExtractTensorPatches`.
"""
def __init__(
self,
original_size: Union[int, Tuple[int, int]],
window_size: Union[int, Tuple[int, int]],
unpadding: Union[int, PadType] = 0,
) -> None:
super().__init__()
self.original_size: Tuple[int, int] = _pair(original_size)
self.window_size: Tuple[int, int] = _pair(window_size)
self.unpadding: PadType = _pair(unpadding)
def forward(self, input: Tensor) -> Tensor:
return combine_tensor_patches(
input, self.original_size, self.window_size, stride=self.window_size, unpadding=self.unpadding
)
def combine_tensor_patches(
patches: Tensor,
original_size: Union[int, Tuple[int, int]],
window_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
unpadding: Union[int, PadType] = 0,
) -> Tensor:
r"""Restore input from patches.
See :class:`~kornia.contrib.CombineTensorPatches` for details.
Args:
patches: patched tensor with shape :math:`(B, N, C, H_{out}, W_{out})`.
        original_size: the size of the original tensor; this is also the size of the output.
        window_size: the size of the sliding window used while extracting patches.
        stride: stride of the sliding window.
        unpadding: remove the padding added to both sides of the input.
Return:
The combined patches in an image tensor with shape :math:`(B, C, H, W)`.
Example:
>>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
>>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])
.. note::
This function is supposed to be used in conjunction with :func:`extract_tensor_patches`.
"""
if len(patches.shape) != 5:
raise ValueError(f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}")
original_size = cast(Tuple[int, int], _pair(original_size))
window_size = cast(Tuple[int, int], _pair(window_size))
stride = cast(Tuple[int, int], _pair(stride))
    if stride[0] != window_size[0] or stride[1] != window_size[1]:
        raise NotImplementedError(
            f"Only stride == window_size is supported. Got {stride} and {window_size}. "
            "Please feel free to drop a PR to Kornia Github."
        )
    # Normalize ``unpadding`` to the ``torch.nn.functional.pad`` ordering
    # (left, right, top, bottom). With the int default of 0, the original
    # subscripting below would raise a TypeError, so fall back to zeros.
    if unpadding:
        unpadding = cast(PadType, _pair(unpadding))
        if len(unpadding) not in (2, 4):
            raise AssertionError("Unpadding must be either an int, a tuple of two ints or a tuple of four ints")
        if len(unpadding) == 2:
            pad_vert = _pair(unpadding[0])
            pad_horz = _pair(unpadding[1])
        else:
            pad_vert = unpadding[:2]
            pad_horz = unpadding[2:]
        unpadding = cast(Tuple[int, int, int, int], pad_horz + pad_vert)
    else:
        unpadding = (0, 0, 0, 0)

    hpad_check = (original_size[0] + unpadding[2] + unpadding[3]) % window_size[0] == 0
    wpad_check = (original_size[1] + unpadding[0] + unpadding[1]) % window_size[1] == 0
    if not hpad_check or not wpad_check:
        raise NotImplementedError("Insufficient padding")

    # Number of patches along the vertical and the horizontal dimension
    # (previously stored by reassigning ``window_size``, which shadowed the
    # actual window size; a separate name avoids the confusion).
    num_patches = (
        (original_size[0] + unpadding[2] + unpadding[3]) // window_size[0],
        (original_size[1] + unpadding[0] + unpadding[1]) // window_size[1],
    )
    # Arrange the flat patch dimension into a (rows, cols) grid, then stitch
    # the grid back together along the spatial dimensions.
    patches_tensor = patches.view(-1, num_patches[0], num_patches[1], *patches.shape[-3:])
    restored_tensor = concatenate(torch.chunk(patches_tensor, num_patches[0], 1), -2).squeeze(1)
    restored_tensor = concatenate(torch.chunk(restored_tensor, num_patches[1], 1), -1).squeeze(1)
    if any(unpadding):
        # Negative pads crop the previously added border.
        restored_tensor = pad(restored_tensor, [-i for i in unpadding])
return restored_tensor
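
# For the non-overlapping case implemented above (``stride == window_size``),
# the same reconstruction can also be expressed with ``torch.nn.functional.fold``.
# This is a minimal sketch for already-unpadded inputs; ``_combine_with_fold``
# is an illustrative helper, not part of the Kornia API.
def _combine_with_fold(
    patches: Tensor, output_size: Tuple[int, int], window_size: Tuple[int, int]
) -> Tensor:
    b, n, c, h, w = patches.shape
    # ``fold`` expects columns of shape (B, C*h*w, N), one column per block,
    # enumerated left-right then top-bottom -- the same order used above.
    cols = patches.permute(0, 2, 3, 4, 1).reshape(b, c * h * w, n)
    # With stride == kernel_size the blocks do not overlap, so the summation
    # performed by ``fold`` reduces to an exact tiling of the output.
    return torch.nn.functional.fold(cols, output_size, kernel_size=window_size, stride=window_size)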
def _extract_tensor_patchesnd(input: Tensor, window_sizes: Tuple[int, ...], strides: Tuple[int, ...]) -> Tensor:
    batch_size, num_channels = input.size()[:2]
    dims = range(2, input.dim())
    # Unfold every spatial dimension; each call appends a window dimension of
    # size ``patch_size`` at the end of the shape.
    for dim, patch_size, stride in zip(dims, window_sizes, strides):
        input = input.unfold(dim, patch_size, stride)
    # Reorder to (B, n_1, ..., n_k, C, w_1, ..., w_k): per-dimension patch
    # counts first, then channels, then the window dimensions.
    input = input.permute(0, *dims, 1, *(dim + len(dims) for dim in dims)).contiguous()
    # Collapse the per-dimension patch counts into a single patch index N.
    return input.view(batch_size, -1, num_channels, *window_sizes)
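
# A quick shape check for the helper above (illustrative only; assumes a 4D
# image tensor): a (1, 3, 4, 4) input with 2x2 windows and stride 2 yields
# four patches per image.
#
#     >>> x = torch.randn(1, 3, 4, 4)
#     >>> _extract_tensor_patchesnd(x, (2, 2), (2, 2)).shape
#     torch.Size([1, 4, 3, 2, 2])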