# Source code for kornia.feature.sold2.sold2

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor, concatenate, pad, stack
from kornia.geometry.conversions import normalize_pixel_coordinates
from kornia.testing import KORNIA_CHECK_SHAPE
from kornia.utils import map_location_to_cpu

from .backbones import SOLD2Net
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"


default_cfg: Dict[str, Any] = {
    'backbone_cfg': {'input_channel': 1, 'depth': 4, 'num_stacks': 2, 'num_blocks': 1, 'num_classes': 5},
    'use_descriptor': True,
    'grid_size': 8,
    'keep_border_valid': True,
    'detection_thresh': 0.0153846,  # = 1/65: threshold of junction detection
    'max_num_junctions': 500,  # maximum number of junctions per image
    'line_detector_cfg': {
        'detect_thresh': 0.5,
        'num_samples': 64,
        'inlier_thresh': 0.99,
        'use_candidate_suppression': True,
        'nms_dist_tolerance': 3.0,
        'use_heatmap_refinement': True,
        'heatmap_refine_cfg': {
            'mode': "local",
            'ratio': 0.2,
            'valid_thresh': 0.001,
            'num_blocks': 20,
            'overlap_ratio': 0.5,
        },
        'use_junction_refinement': True,
        'junction_refine_cfg': {'num_perturbs': 9, 'perturb_interval': 0.25},
    },
    'line_matcher_cfg': {
        'cross_check': True,
        'num_samples': 5,
        'min_dist_pts': 8,
        'top_k_candidates': 10,
        'grid_size': 4,
    },
}
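
# Usage sketch (added for illustration; the helper name `_custom_config_example`
# is ours, not part of the upstream module): the defaults above can be
# overridden by passing a modified copy of `default_cfg` to `SOLD2`. The value
# 250 below is an arbitrary example.
def _custom_config_example() -> "SOLD2":
    custom_cfg = dict(default_cfg)  # shallow copy; nested dicts stay shared
    custom_cfg["max_num_junctions"] = 250  # keep fewer junctions per image
    # `pretrained=False` skips the weight download for this sketch.
    return SOLD2(pretrained=False, config=custom_cfg)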


class SOLD2(Module):
    r"""Module that detects and describes line segments in an image.

    This is based on the original code from the paper "SOLD²: Self-supervised
    Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.

    Args:
        config: Dict specifying parameters. None will load the default parameters,
            which are tuned for images in the range 400~800 px.
        pretrained: If True, download and set pretrained weights to the model.

    Returns:
        The raw junction and line heatmaps, the semi-dense descriptor map,
        as well as the list of detected line segments (ij coordinates convention).

    Example:
        >>> images = torch.rand(2, 1, 512, 512)
        >>> sold2 = SOLD2()
        >>> outputs = sold2(images)
        >>> line_seg1 = outputs["line_segments"][0]
        >>> line_seg2 = outputs["line_segments"][1]
        >>> desc1 = outputs["dense_desc"][0]
        >>> desc2 = outputs["dense_desc"][1]
        >>> matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
    """

    def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        # Initialize some parameters
        self.config = default_cfg if config is None else config
        self.grid_size = self.config["grid_size"]
        self.junc_detect_thresh = self.config.get("detection_thresh", 1 / 65)
        self.max_num_junctions = self.config.get("max_num_junctions", 500)

        # Load the pre-trained model
        self.model = SOLD2Net(self.config)
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu)
            state_dict = self.adapt_state_dict(pretrained_dict['model_state_dict'])
            self.model.load_state_dict(state_dict)
        self.eval()

        # Initialize the line detector
        self.line_detector_cfg = self.config["line_detector_cfg"]
        self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"])

        # Initialize the line matcher
        self.line_matcher = WunschLineMatcher(**self.config["line_matcher_cfg"])

    def forward(self, img: Tensor) -> Dict[str, Any]:
        """
        Args:
            img: batched images with shape :math:`(B, 1, H, W)`.

        Return:
            - ``line_segments``: list of N line segments in each of the B images :math:`List[(N, 2, 2)]`.
            - ``junction_heatmap``: raw junction heatmap of shape :math:`(B, H, W)`.
            - ``line_heatmap``: raw line heatmap of shape :math:`(B, H, W)`.
            - ``dense_desc``: the semi-dense descriptor map of shape :math:`(B, 128, H/4, W/4)`.
        """
        KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
        outputs = {}
        # Forward pass of the CNN backbone
        net_outputs = self.model(img)
        outputs["junction_heatmap"] = net_outputs["junctions"]
        outputs["line_heatmap"] = net_outputs["heatmap"]
        outputs["dense_desc"] = net_outputs["descriptors"]

        # Loop through all images
        lines = []
        for junc_prob, heatmap in zip(net_outputs["junctions"], net_outputs["heatmap"]):
            # Get the junctions
            junctions = prob_to_junctions(junc_prob, self.grid_size, self.junc_detect_thresh, self.max_num_junctions)

            # Run the line detector
            line_map, junctions, _ = self.line_detector.detect(junctions, heatmap)
            lines.append(line_map_to_segments(junctions, line_map))
        outputs["line_segments"] = lines

        return outputs

    def match(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors.

        Args:
            line_seg1, line_seg2: line segments in the two images, each with shape [num_lines, 2, 2].
            desc1, desc2: semi-dense descriptor maps of the images, with shape [1, 128, H/4, W/4].

        Returns:
            A Tensor of size [num_lines1] indicating the index in line_seg2 of the matched line,
            for each line in line_seg1. -1 means that the line is not matched.
        """
        return self.line_matcher(line_seg1, line_seg2, desc1, desc2)

    def adapt_state_dict(self, state_dict):
        # Remove training-only entries and remap the heatmap decoder keys to
        # the layer names expected by SOLD2Net.
        del state_dict["w_junc"]
        del state_dict["w_heatmap"]
        del state_dict["w_desc"]
        state_dict["heatmap_decoder.conv_block_lst.2.0.weight"] = state_dict["heatmap_decoder.conv_block_lst.2.weight"]
        state_dict["heatmap_decoder.conv_block_lst.2.0.bias"] = state_dict["heatmap_decoder.conv_block_lst.2.bias"]
        del state_dict["heatmap_decoder.conv_block_lst.2.weight"]
        del state_dict["heatmap_decoder.conv_block_lst.2.bias"]
        return state_dict
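
# Usage sketch (added for illustration; the helper name `_sold2_matching_example`
# is ours, not part of the upstream module): end-to-end line detection and
# matching on a pair of images, following the doctest in the class docstring.
# The 512x512 size is an arbitrary choice.
def _sold2_matching_example() -> Tensor:
    images = torch.rand(2, 1, 512, 512)
    sold2 = SOLD2(pretrained=False)  # `pretrained=False` skips the weight download
    outputs = sold2(images)
    line_seg1, line_seg2 = outputs["line_segments"]
    desc1, desc2 = outputs["dense_desc"]
    # `matches[i]` is the index in `line_seg2` of the line matched to
    # `line_seg1[i]`, or -1 when the line has no (mutual) match.
    return sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])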
class WunschLineMatcher(Module):
    """Class matching two sets of line segments with the Needleman-Wunsch algorithm.

    TODO: move it later in kornia.feature.matching
    """

    def __init__(
        self,
        cross_check: bool = True,
        num_samples: int = 10,
        min_dist_pts: int = 8,
        top_k_candidates: int = 10,
        grid_size: int = 8,
        line_score: bool = False,
    ):
        super().__init__()
        self.cross_check = cross_check
        self.num_samples = num_samples
        self.min_dist_pts = min_dist_pts
        self.top_k_candidates = top_k_candidates
        self.grid_size = grid_size
        self.line_score = line_score  # True to compute saliency on a line

    def forward(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors."""
        KORNIA_CHECK_SHAPE(line_seg1, ["N", "2", "2"])
        KORNIA_CHECK_SHAPE(line_seg2, ["N", "2", "2"])
        KORNIA_CHECK_SHAPE(desc1, ["B", "D", "H", "W"])
        KORNIA_CHECK_SHAPE(desc2, ["B", "D", "H", "W"])
        device = desc1.device
        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)

        # Default case when an image has no lines
        if len(line_seg1) == 0:
            return torch.empty(0, dtype=torch.int, device=device)
        if len(line_seg2) == 0:
            return -torch.ones(len(line_seg1), dtype=torch.int, device=device)

        # Sample points regularly along each line
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
        line_points1 = line_points1.reshape(-1, 2)
        line_points2 = line_points2.reshape(-1, 2)

        # Extract the descriptors for each point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1, align_corners=False)[0, :, :, 0], dim=0)
        desc2 = F.normalize(F.grid_sample(desc2, grid2, align_corners=False)[0, :, :, 0], dim=0)

        # Precompute the descriptor similarity between line points for every
        # pair of lines. Assign a score of -1 for invalid points.
        scores = desc1.t() @ desc2
        scores[~valid_points1.flatten()] = -1
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(len(line_seg1), self.num_samples, len(line_seg2), self.num_samples)
        scores = scores.permute(0, 2, 1, 3)
        # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

        # Pre-filter the line candidates and find the best match for each line
        matches = self.filter_and_match_lines(scores)

        # Optionally filter matches with mutual nearest neighbor filtering
        if self.cross_check:
            matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
            mutual = matches2[matches] == torch.arange(len(line_seg1), device=device)
            matches[~mutual] = -1

        return matches

    def sample_line_points(self, line_seg: Tensor) -> Tuple[Tensor, Tensor]:
        """Regularly sample points along each line segment, with a minimal distance between consecutive points.

        Pad the remaining points.

        Inputs:
            line_seg: an Nx2x2 Tensor.
        Outputs:
            line_points: an N x num_samples x 2 Tensor.
            valid_points: a boolean N x num_samples Tensor.
        """
        KORNIA_CHECK_SHAPE(line_seg, ["N", "2", "2"])
        num_lines = len(line_seg)
        line_lengths = torch.norm(line_seg[:, 0] - line_seg[:, 1], dim=1)

        # Sample points separated by at least min_dist_pts along each line;
        # the number of samples depends on the length of the line.
        num_samples_lst = torch.clamp(
            torch.div(line_lengths, self.min_dist_pts, rounding_mode='floor'), 2, self.num_samples
        ).int()
        line_points = torch.empty((num_lines, self.num_samples, 2), dtype=torch.float)
        valid_points = torch.empty((num_lines, self.num_samples), dtype=torch.bool)
        for n_samp in range(2, self.num_samples + 1):
            # Consider all lines where we can fit up to n_samp points
            cur_mask = num_samples_lst == n_samp
            cur_line_seg = line_seg[cur_mask]
            line_points_x = batched_linspace(cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n_samp, dim=-1)
            line_points_y = batched_linspace(cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n_samp, dim=-1)

            cur_line_points = stack([line_points_x, line_points_y], -1)

            # Pad the unused samples
            cur_line_points = pad(cur_line_points, (0, 0, 0, self.num_samples - n_samp))
            cur_valid_points = torch.ones(len(cur_line_seg), self.num_samples, dtype=torch.bool)
            cur_valid_points[:, n_samp:] = False

            line_points[cur_mask] = cur_line_points
            valid_points[cur_mask] = cur_valid_points

        return line_points, valid_points

    def filter_and_match_lines(self, scores: Tensor) -> Tensor:
        """Use the scores to keep the top k best lines, run the Needleman-Wunsch algorithm on each candidate pair, and keep the highest score.

        Inputs:
            scores: a (N, M, n, n) Tensor containing the pairwise scores
                of the elements to match.
        Outputs:
            matches: a (N,) Tensor containing the indices of the best matches.
        """
        KORNIA_CHECK_SHAPE(scores, ["N", "M", "n", "n"])

        # Pre-filter the pairs and keep the top k best candidate lines
        line_scores1 = scores.max(3)[0]
        valid_scores1 = line_scores1 != -1
        line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
        line_scores2 = scores.max(2)[0]
        valid_scores2 = line_scores2 != -1
        line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
        line_scores = (line_scores1 + line_scores2) / 2
        topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
        # topk_lines.shape = (n_lines1, top_k_candidates)
        top_scores = torch.take_along_dim(scores, topk_lines[:, :, None, None], dim=1)

        # Consider the reversed line segments as well
        top_scores = concatenate([top_scores, torch.flip(top_scores, dims=[-1])], 1)

        # Compute the line distance matrix with the Needleman-Wunsch algorithm
        # and retrieve the closest line neighbor
        n_lines1, top2k, n, m = top_scores.shape
        top_scores = top_scores.reshape((n_lines1 * top2k, n, m))
        nw_scores = self.needleman_wunsch(top_scores)
        nw_scores = nw_scores.reshape(n_lines1, top2k)
        matches = torch.remainder(torch.argmax(nw_scores, dim=1), top2k // 2)
        matches = topk_lines[torch.arange(n_lines1), matches]
        return matches

    def needleman_wunsch(self, scores: Tensor) -> Tensor:
        """Batched implementation of the Needleman-Wunsch algorithm.

        The cost of the InDel operation is set to 0 by subtracting the gap
        penalty from the scores.

        Inputs:
            scores: a (B, N, M) Tensor containing the pairwise scores
                of the elements to match.
        Outputs:
            a (B,) Tensor with the alignment score of each batch element.
        """
        KORNIA_CHECK_SHAPE(scores, ["B", "N", "M"])
        b, n, m = scores.shape

        # Recalibrate the scores to get a gap score of 0
        gap = 0.1
        nw_scores = scores - gap

        # Run the dynamic programming algorithm
        nw_grid = torch.zeros(b, n + 1, m + 1, dtype=torch.float)
        for i in range(n):
            for j in range(m):
                nw_grid[:, i + 1, j + 1] = torch.maximum(
                    torch.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]), nw_grid[:, i, j] + nw_scores[:, i, j]
                )

        return nw_grid[:, -1, -1]


def keypoints_to_grid(keypoints: Tensor, img_size: Tuple[int, int]) -> Tensor:
    """Convert a list of keypoints into a grid in [-1, 1]² that can be used in torch.nn.functional.grid_sample.

    Args:
        keypoints: a tensor [N, 2] of N keypoints (ij coordinates convention).
        img_size: the original image size (H, W).
    """
    KORNIA_CHECK_SHAPE(keypoints, ["N", "2"])
    n_points = len(keypoints)
    # Flip to xy convention and normalize the coordinates to [-1, 1]
    grid_points = normalize_pixel_coordinates(keypoints[:, [1, 0]], img_size[0], img_size[1])
    grid_points = grid_points.view(-1, n_points, 1, 2)
    return grid_points


def batched_linspace(start, end, step, dim):
    """Batched version of torch.linspace (similar to the numpy one)."""
    intervals = ((end - start) / (step - 1)).unsqueeze(dim)
    broadcast_size = [1] * len(intervals.shape)
    broadcast_size[dim] = step
    samples = torch.arange(step, dtype=torch.float, device=start.device).reshape(broadcast_size)
    samples = start.unsqueeze(dim) + samples * intervals
    return samples