Source code for kornia.tracking.planar_tracker

from typing import Dict, Optional, Tuple

import torch

from kornia.core import Module, Tensor
from kornia.feature import DescriptorMatcher, GFTTAffNetHardNet, LocalFeatureMatcher, LoFTR
from kornia.geometry.linalg import transform_points
from kornia.geometry.ransac import RANSAC
from kornia.geometry.transform import warp_perspective


class HomographyTracker(Module):
    r"""Module that performs local-feature-based tracking of a target planar object across a sequence of frames.

    Args:
        initial_matcher: image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
            or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.LocalFeatureMatcher`
            with :class:`~kornia.feature.GFTTAffNetHardNet` local features.
        fast_matcher: fast image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
            or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.LoFTR`.
        ransac: homography estimation module. Default: :class:`~kornia.geometry.RANSAC`.
        minimum_inliers_num: threshold on the number of inliers for the matching to be considered successful.

    A commented usage sketch is given at the end of this module.
    """

    def __init__(
        self,
        initial_matcher: Optional[Module] = None,
        fast_matcher: Optional[Module] = None,
        ransac: Optional[Module] = None,
        minimum_inliers_num: int = 30,
    ) -> None:
        super().__init__()
        self.initial_matcher = initial_matcher or LocalFeatureMatcher(
            GFTTAffNetHardNet(3000), DescriptorMatcher('smnn', 0.95)
        )
        self.fast_matcher = fast_matcher or LoFTR('outdoor')
        self.ransac = ransac or RANSAC('homography', inl_th=5.0, batch_size=4096, max_iter=10, max_lo_iters=10)
        self.minimum_inliers_num = minimum_inliers_num
        # placeholders
        self.target: Tensor
        self.target_initial_representation: Dict[str, Tensor] = {}
        self.target_fast_representation: Dict[str, Tensor] = {}
        self.previous_homography: Optional[Tensor] = None
        self.inliers_num: int = 0
        self.keypoints0_num: int = 0
        self.keypoints1_num: int = 0
        self.reset_tracking()

    @property
    def device(self) -> torch.device:
        return self.target.device

    @property
    def dtype(self) -> torch.dtype:
        return self.target.dtype

    @torch.no_grad()
    def set_target(self, target: Tensor) -> None:
        self.target = target
        self.target_initial_representation = {}
        self.target_fast_representation = {}
        # precompute the target features once, if the matcher exposes a feature extractor
        if hasattr(self.initial_matcher, 'extract_features') and isinstance(
            self.initial_matcher.extract_features, Module
        ):
            self.target_initial_representation = self.initial_matcher.extract_features(target)
        if hasattr(self.fast_matcher, 'extract_features') and isinstance(self.fast_matcher.extract_features, Module):
            self.target_fast_representation = self.fast_matcher.extract_features(target)

    def reset_tracking(self) -> None:
        self.previous_homography = None

    def no_match(self) -> Tuple[Tensor, bool]:
        self.inliers_num = 0
        self.keypoints0_num = 0
        self.keypoints1_num = 0
        return torch.empty(3, 3, device=self.device, dtype=self.dtype), False
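
    # Two-stage design note: `match_initial` performs a full wide-baseline match
    # against the target, while `track_next_frame` reuses the previous homography
    # to prewarp the incoming frame, so the fast matcher only has to refine a
    # near-identity alignment.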
    def match_initial(self, x: Tensor) -> Tuple[Tensor, bool]:
        """The frame `x` is matched against the target with the initial_matcher and verified with RANSAC."""
        input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": x}
        for k, v in self.target_initial_representation.items():
            input_dict[f'{k}0'] = v
        match_dict: Dict[str, Tensor] = self.initial_matcher(input_dict)
        keypoints0 = match_dict['keypoints0'][match_dict['batch_indexes'] == 0]
        keypoints1 = match_dict['keypoints1'][match_dict['batch_indexes'] == 0]
        self.keypoints0_num = len(keypoints0)
        self.keypoints1_num = len(keypoints1)
        if self.keypoints0_num < self.minimum_inliers_num:
            return self.no_match()
        H, inliers = self.ransac(keypoints0, keypoints1)
        self.inliers_num = inliers.sum().item()
        if self.inliers_num < self.minimum_inliers_num:
            return self.no_match()
        self.previous_homography = H.clone()
        return H, True
    def track_next_frame(self, x: Tensor) -> Tuple[Tensor, bool]:
        """The frame `x` is prewarped according to the previous frame homography,
        then matched with the fast_matcher and verified with RANSAC."""
        if self.previous_homography is not None:  # mypy, shut up
            Hwarp = self.previous_homography.clone()[None]
        # make a bit of border for safety: growing the linear part by 1/0.8 and
        # shifting the translation by 10 px keeps the target inside the warped crop
        Hwarp[:, 0:2, 0:2] = Hwarp[:, 0:2, 0:2] / 0.8
        Hwarp[:, 0:2, 2] -= 10.0
        Hinv = torch.inverse(Hwarp)
        h, w = self.target.shape[2:]
        frame_warped = warp_perspective(x, Hinv, (h, w))
        input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": frame_warped}
        for k, v in self.target_fast_representation.items():
            input_dict[f'{k}0'] = v
        match_dict = self.fast_matcher(input_dict)
        keypoints0 = match_dict['keypoints0'][match_dict['batch_indexes'] == 0]
        keypoints1 = match_dict['keypoints1'][match_dict['batch_indexes'] == 0]
        # map the matches found in the prewarped frame back to the original frame coordinates
        keypoints1 = transform_points(Hwarp, keypoints1)
        self.keypoints0_num = len(keypoints0)
        self.keypoints1_num = len(keypoints1)
        if self.keypoints0_num < self.minimum_inliers_num:
            self.reset_tracking()
            return self.no_match()
        H, inliers = self.ransac(keypoints0, keypoints1)
        self.inliers_num = inliers.sum().item()
        if self.inliers_num < self.minimum_inliers_num:
            self.reset_tracking()
            return self.no_match()
        self.previous_homography = H.clone()
        return H, True
    def forward(self, x: Tensor) -> Tuple[Tensor, bool]:
        """Match from scratch on the first frame (or after a lost track), then track frame-to-frame."""
        if self.previous_homography is not None:
            return self.track_next_frame(x)
        return self.match_initial(x)
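

# Usage sketch (illustration only, not part of the module). The tensor shapes and
# frame source below are assumptions: `set_target` and `forward` take image batches
# shaped (1, C, H, W), and the default LoFTR fast matcher expects single-channel
# (grayscale) images. On failure, `forward` returns a placeholder 3x3 tensor and
# ``False``, and the next call falls back to `match_initial` automatically.
#
#   import torch
#   from kornia.tracking.planar_tracker import HomographyTracker
#
#   tracker = HomographyTracker()
#   target = torch.rand(1, 1, 480, 640)      # stand-in for a real target template
#   tracker.set_target(target)               # precompute the target features once
#   for _ in range(10):
#       frame = torch.rand(1, 1, 480, 640)   # stand-in for a real video frame
#       homography, ok = tracker(frame)      # 3x3 homography (target -> frame), success flag
#       if ok:
#           pass                             # e.g. warp the target outline with transform_points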