Source code for kornia.feature.keynet

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import TypedDict

from kornia.core import Module, Tensor, concatenate, tensor, where, zeros
from kornia.filters import SpatialGradient
from kornia.geometry.subpix import NonMaximaSuppression2d
from kornia.geometry.transform import pyrdown
from kornia.utils.helpers import map_location_to_cpu

from .laf import laf_from_center_scale_ori
from .orientation import PassLAF


class KeyNet_conf(TypedDict):
    num_filters: int
    num_levels: int
    kernel_size: int
    nms_size: int
    pyramid_levels: int
    up_levels: int
    scale_factor_levels: float
    s_mult: float


keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    'num_filters': 8,
    'num_levels': 3,
    'kernel_size': 5,
    # Extraction Parameters
    'nms_size': 15,
    'pyramid_levels': 4,
    'up_levels': 1,
    'scale_factor_levels': math.sqrt(2),
    's_mult': 22.0,
}
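
# A minimal sketch of overriding the defaults (illustrative values only; copy
# the dict and replace individual keys before constructing the model):
#
#   my_conf: KeyNet_conf = dict(keynet_default_config)
#   my_conf['pyramid_levels'] = 3  # fewer downscale levels
#   my_conf['nms_size'] = 9        # tighter non-maxima suppression window
#   model = KeyNet(pretrained=False, keynet_conf=my_conf)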

KeyNet_URL = "https://github.com/axelBarroso/Key.Net-Pytorch/raw/main/model/weights/keynet_pytorch.pth"


class _FeatureExtractor(Module):
    """Helper class for KeyNet.

    It loads both, the handcrafted and learnable blocks
    """

    def __init__(self):
        super().__init__()

        self.hc_block = _HandcraftedBlock()
        self.lb_block = _LearnableBlock()

    def forward(self, x: Tensor) -> Tensor:
        x_hc = self.hc_block(x)
        x_lb = self.lb_block(x_hc)
        return x_lb


class _HandcraftedBlock(Module):
    """Helper class for KeyNet, it defines the handcrafted filters within the Key.Net handcrafted block."""

    def __init__(self):
        super().__init__()
        self.spatial_gradient = SpatialGradient('sobel', 1)

    def forward(self, x: Tensor) -> Tensor:
        sobel = self.spatial_gradient(x)
        dx, dy = sobel[:, :, 0, :, :], sobel[:, :, 1, :, :]

        sobel_dx = self.spatial_gradient(dx)
        dxx, dxy = sobel_dx[:, :, 0, :, :], sobel_dx[:, :, 1, :, :]

        sobel_dy = self.spatial_gradient(dy)
        dyy = sobel_dy[:, :, 1, :, :]

        # Stack the 10 handcrafted channels: first-order derivatives, their
        # squares and product, and second-order derivatives with combinations.
        hc_feats = concatenate([dx, dy, dx**2.0, dy**2.0, dx * dy, dxy, dxy**2.0, dxx, dyy, dxx * dyy], 1)

        return hc_feats


class _LearnableBlock(nn.Sequential):
    """Helper class for KeyNet.

    It defines the learnable blocks within the Key.Net
    """

    def __init__(self, in_channels: int = 10):
        super().__init__()

        self.conv0 = _KeyNetConvBlock(in_channels)
        self.conv1 = _KeyNetConvBlock()
        self.conv2 = _KeyNetConvBlock()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv2(self.conv1(self.conv0(x)))
        return x


def _KeyNetConvBlock(
    in_channels: int = 8,
    out_channels: int = 8,
    kernel_size: int = 5,
    stride: int = 1,
    padding: int = 2,
    dilation: int = 1,
) -> nn.Sequential:
    """Helper function for KeyNet.

    Default learnable convolutional block for KeyNet.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
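
# Note: with the defaults above (kernel_size=5, stride=1, padding=2, dilation=1)
# the block preserves spatial resolution, since
# H_out = (H + 2 * padding - kernel_size) // stride + 1 = H,
# so the per-level feature maps keep the shape needed for later concatenation.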


class KeyNet(Module):
    """Key.Net model definition -- local feature detector (response function).

    This is based on the original code from the paper "Key.Net: Keypoint Detection
    by Handcrafted and Learned CNN Filters". See :cite:`KeyNet2019` for more details.

    Args:
        pretrained: Download and set pretrained weights to the model.
        keynet_conf: Dict with initialization parameters. Do not pass it unless you know what you are doing.

    Returns:
        KeyNet response score.

    Shape:
        - Input: :math:`(B, 1, H, W)`
        - Output: :math:`(B, 1, H, W)`
    """

    def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config):
        super().__init__()

        num_filters = keynet_conf['num_filters']
        self.num_levels = keynet_conf['num_levels']
        kernel_size = keynet_conf['kernel_size']
        padding = kernel_size // 2

        self.feature_extractor = _FeatureExtractor()
        self.last_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=num_filters * self.num_levels, out_channels=1, kernel_size=kernel_size, padding=padding
            ),
            nn.ReLU(inplace=True),
        )

        # use torch.hub to load pretrained model
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(KeyNet_URL, map_location=map_location_to_cpu)
            self.load_state_dict(pretrained_dict['state_dict'], strict=True)
        self.eval()

    def forward(self, x: Tensor) -> Tensor:
        """Compute the multi-scale response map for the input image x."""
        shape_im = x.shape
        feats: List[Tensor] = [self.feature_extractor(x)]
        for i in range(1, self.num_levels):
            x = pyrdown(x, factor=1.2)
            feats_i = self.feature_extractor(x)
            # upsample the coarser-level features back to the input resolution
            feats_i = F.interpolate(feats_i, size=(shape_im[2], shape_im[3]), mode='bilinear')
            feats.append(feats_i)

        scores = self.last_conv(concatenate(feats, 1))
        return scores
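
# A minimal usage sketch of the KeyNet response function (shapes as in the
# docstring above; the random input is purely illustrative):
#
#   net = KeyNet(pretrained=False).eval()
#   img = torch.rand(1, 1, 128, 128)  # single-channel image batch
#   with torch.no_grad():
#       scores = net(img)  # (1, 1, 128, 128) response map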


class KeyNetDetector(Module):
    """Multi-scale feature detector based on KeyNet.

    This is based on the original code from the paper "Key.Net: Keypoint Detection
    by Handcrafted and Learned CNN Filters". See :cite:`KeyNet2019` for more details.

    Args:
        pretrained: Download and set pretrained weights to the model.
        num_features: Number of features to detect.
        keynet_conf: Dict with initialization parameters. Do not pass it unless you know what you are doing.
        ori_module: for local feature orientation estimation. Default: :class:`~kornia.feature.PassLAF`,
            which does nothing. See :class:`~kornia.feature.LAFOrienter` for details.
        aff_module: for local feature affine shape estimation. Default: :class:`~kornia.feature.PassLAF`,
            which does nothing. See :class:`~kornia.feature.LAFAffineShapeEstimator` for details.
    """

    def __init__(
        self,
        pretrained: bool = False,
        num_features: int = 2048,
        keynet_conf: KeyNet_conf = keynet_default_config,
        ori_module: Optional[Module] = None,
        aff_module: Optional[Module] = None,
    ):
        super().__init__()
        self.model = KeyNet(pretrained, keynet_conf)
        # Load extraction configuration
        self.num_pyramid_levels = keynet_conf['pyramid_levels']
        self.num_upscale_levels = keynet_conf['up_levels']
        self.scale_factor_levels = keynet_conf['scale_factor_levels']
        self.mr_size = keynet_conf['s_mult']
        self.nms_size = keynet_conf['nms_size']
        self.nms = NonMaximaSuppression2d((self.nms_size, self.nms_size))
        self.num_features = num_features

        if ori_module is None:
            self.ori: Module = PassLAF()
        else:
            self.ori = ori_module

        if aff_module is None:
            self.aff: Module = PassLAF()
        else:
            self.aff = aff_module

    def remove_borders(self, score_map: Tensor, borders: int = 15) -> Tensor:
        """Remove the borders of the score map to avoid detections at the image corners."""
        mask = torch.zeros_like(score_map)
        mask[:, :, borders:-borders, borders:-borders] = 1
        return mask * score_map

    def detect_features_on_single_level(
        self, level_img: Tensor, num_kp: int, factor: Tuple[float, float]
    ) -> Tuple[Tensor, Tensor]:
        det_map = self.nms(self.remove_borders(self.model(level_img)))
        device = level_img.device
        dtype = level_img.dtype
        yx = det_map.nonzero()[:, 2:].t()
        scores = det_map[0, 0, yx[0], yx[1]]  # keynet supports only non-batched images

        scores_sorted, indices = torch.sort(scores, descending=True)

        indices = indices[where(scores_sorted > 0.0)]
        yx = yx[:, indices[:num_kp]].t()
        current_kp_num = len(yx)
        # flip (y, x) -> (x, y) and rescale keypoints back to the original image coordinates
        xy_projected = yx.view(1, current_kp_num, 2).flip(2) * tensor(factor, device=device, dtype=dtype)
        scale_factor = 0.5 * (factor[0] + factor[1])
        scale = scale_factor * self.mr_size * torch.ones(1, current_kp_num, 1, 1, device=device, dtype=dtype)
        lafs = laf_from_center_scale_ori(xy_projected, scale, zeros(1, current_kp_num, 1, device=device, dtype=dtype))
        return scores_sorted[:num_kp], lafs

    def detect(self, img: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Compute points per level: each level gets a share of the budget
        # proportional to its area (one factor of scale_factor_levels ** 2 per
        # step), normalized by tmp so the shares sum to roughly num_features.
        num_features_per_level: List[float] = []
        tmp = 0.0
        factor_points = self.scale_factor_levels**2
        levels = self.num_pyramid_levels + self.num_upscale_levels + 1
        for idx_level in range(levels):
            tmp += factor_points ** (-1 * (idx_level - self.num_upscale_levels))
            nf = self.num_features * factor_points ** (-1 * (idx_level - self.num_upscale_levels))
            num_features_per_level.append(nf)
        num_features_per_level = list(map(lambda x: int(x / tmp), num_features_per_level))

        _, _, h, w = img.shape
        img_up = img
        cur_img = img
        all_responses: List[Tensor] = []
        all_lafs: List[Tensor] = []

        # Extract features from the upper levels
        for idx_level in range(self.num_upscale_levels):
            nf = num_features_per_level[len(num_features_per_level) - self.num_pyramid_levels - 1 - (idx_level + 1)]
            num_points_level = int(nf)

            # Resize input image
            up_factor = self.scale_factor_levels ** (1 + idx_level)
            nh, nw = int(h * up_factor), int(w * up_factor)
            up_factor_kpts = (float(w) / float(nw), float(h) / float(nh))
            img_up = F.interpolate(img_up, (nh, nw), mode='bilinear', align_corners=False)

            cur_scores, cur_lafs = self.detect_features_on_single_level(img_up, num_points_level, up_factor_kpts)
            all_responses.append(cur_scores.view(1, -1))
            all_lafs.append(cur_lafs)

        # Extract features from the downsampling pyramid
        for idx_level in range(self.num_pyramid_levels + 1):
            if idx_level > 0:
                cur_img = pyrdown(cur_img, factor=self.scale_factor_levels)
                _, _, nh, nw = cur_img.shape
                factor = (float(w) / float(nw), float(h) / float(nh))
            else:
                factor = (1.0, 1.0)
            num_points_level = int(num_features_per_level[idx_level])
            if idx_level > 0 or (self.num_upscale_levels > 0):
                nf2 = [num_features_per_level[a] for a in range(0, idx_level + 1 + self.num_upscale_levels)]
                res_points = Tensor(nf2).sum().item()
                num_points_level = int(res_points)

            cur_scores, cur_lafs = self.detect_features_on_single_level(cur_img, num_points_level, factor)
            all_responses.append(cur_scores.view(1, -1))
            all_lafs.append(cur_lafs)

        responses: Tensor = concatenate(all_responses, 1)
        lafs: Tensor = concatenate(all_lafs, 1)
        if lafs.shape[1] > self.num_features:
            responses, idxs = torch.topk(responses, k=self.num_features, dim=1)
            lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
        return responses, lafs
    def forward(self, img: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Three-stage local feature detection.

        First, the location and scale of interest points are determined by the ``detect`` function.
        Then the affine shape and orientation are estimated.

        Args:
            img: image to extract features with shape :math:`(1, C, H, W)`. KeyNetDetector does not support
                batch processing, because the number of detections differs from image to image.
            mask: a mask with weights where to apply the response function. The shape must be the same as
                the input image.

        Returns:
            lafs: shape :math:`(1, N, 2, 3)`. Detected local affine frames.
            responses: shape :math:`(1, N, 1)`. Response function values for corresponding lafs.
        """
        if img.shape[0] != 1:
            raise ValueError("KeyNet supports only single-image input")
        responses, lafs = self.detect(img, mask)
        lafs = self.aff(lafs, img)
        lafs = self.ori(lafs, img)
        return lafs, responses
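
# A minimal usage sketch of the detector (single grayscale image only, as the
# forward docstring states; the random input is purely illustrative):
#
#   detector = KeyNetDetector(pretrained=False, num_features=512).eval()
#   img = torch.rand(1, 1, 256, 256)
#   with torch.no_grad():
#       lafs, responses = detector(img)  # lafs: (1, N, 2, 3), with matching responses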