# Source code for kornia.geometry.transform.homography_warper

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils import create_meshgrid

from .imgwarp import homography_warp, warp_grid

# Public API of this module: only the HomographyWarper module is exported.
__all__ = ["HomographyWarper"]


class HomographyWarper(nn.Module):
    r"""Warp tensors by homographies.

    .. math::

        X_{dst} = H_{src}^{\{dst\}} * X_{src}

    Args:
        height: The height of the destination tensor.
        width: The width of the destination tensor.
        mode: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        normalized_coordinates: whether to use a grid with normalized coordinates.
        align_corners: interpolation flag.
    """

    # Sampling grid cached by precompute_warp_grid(); None until it is called.
    _warped_grid: Optional[torch.Tensor]

    def __init__(
        self,
        height: int,
        width: int,
        mode: str = 'bilinear',
        padding_mode: str = 'zeros',
        normalized_coordinates: bool = True,
        align_corners: bool = False,
    ) -> None:
        super().__init__()
        self.width: int = width
        self.height: int = height
        self.mode: str = mode
        self.padding_mode: str = padding_mode
        self.normalized_coordinates: bool = normalized_coordinates
        self.align_corners: bool = align_corners

        # Base grid of the destination size, used to compute the flow.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # Module.to()/cuda() do not move it; presumably warp_grid() handles the
        # device/dtype conversion -- confirm.
        self.grid: torch.Tensor = create_meshgrid(height, width, normalized_coordinates=normalized_coordinates)

        # Initialize the warped destination grid (filled by precompute_warp_grid()).
        self._warped_grid = None

    def precompute_warp_grid(self, src_homo_dst: torch.Tensor) -> None:
        r"""Compute and store internally the transformations of the points.

        Useful when the same homography/homographies are reused.

        Args:
            src_homo_dst: Homography or homographies (stacked) to
              transform all points in the grid. Shape of the homography
              has to be :math:`(1, 3, 3)` or :math:`(N, 1, 3, 3)`.
              The homography assumes normalized coordinates [-1, 1] if
              normalized_coordinates is True.
        """
        self._warped_grid = warp_grid(self.grid, src_homo_dst)

    def forward(self, patch_src: torch.Tensor, src_homo_dst: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Warp a tensor from source into reference frame.

        If ``src_homo_dst`` is given, the warp is computed on the fly. Otherwise
        the grid stored by :meth:`precompute_warp_grid` is used.

        Args:
            patch_src: The tensor to warp.
            src_homo_dst: The homography or stack of homographies
              from destination to source. The homography assumes
              normalized coordinates [-1, 1] if normalized_coordinates is True.

        Return:
            Patch sampled at locations from source to destination.

        Raises:
            TypeError: if the precomputed grid and ``patch_src`` are on
              different devices.
            RuntimeError: if no homography is given and no grid has been
              precomputed.

        Shape:
            - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)`
            - Output: :math:`(N, C, H, W)`

        Example:
            >>> input = torch.rand(1, 3, 32, 32)
            >>> homography = torch.eye(3).view(1, 3, 3)
            >>> warper = HomographyWarper(32, 32)
            >>> # without precomputing the warp
            >>> output = warper(input, homography)  # NxCxHxW
            >>> # precomputing the warp
            >>> warper.precompute_warp_grid(homography)
            >>> output = warper(input)  # NxCxHxW
        """
        # Read the attribute once so the Optional is narrowed on a local.
        _warped_grid = self._warped_grid

        if src_homo_dst is not None:
            # An explicit homography takes precedence over any precomputed grid.
            warped_patch = homography_warp(
                patch_src,
                src_homo_dst,
                (self.height, self.width),
                mode=self.mode,
                padding_mode=self.padding_mode,
                align_corners=self.align_corners,
                normalized_coordinates=self.normalized_coordinates,
            )
        elif _warped_grid is not None:
            if _warped_grid.device != patch_src.device:
                # Adjacent literals are joined at compile time, so .format()
                # applies to the whole message (no embedded indentation).
                raise TypeError(
                    "Patch and warped grid must be on the same device. "
                    "Got patch.device: {} warped_grid.device: {}. Either "
                    "call precompute_warp_grid() again with a homography on "
                    "the correct device or move the patch to the grid's "
                    "device.".format(patch_src.device, _warped_grid.device)
                )
            warped_patch = F.grid_sample(
                patch_src,
                _warped_grid,
                mode=self.mode,
                padding_mode=self.padding_mode,
                align_corners=self.align_corners,
            )
        else:
            raise RuntimeError(
                "Unknown warping. If homographies are not provided "
                "they must be preset using the method: "
                "precompute_warp_grid()."
            )

        return warped_patch