"""Module containing RANSAC modules."""
import math
from typing import Callable, Optional, Tuple

import torch

from kornia.core import Device, Module, Tensor, zeros
from kornia.geometry import (
find_fundamental,
find_homography_dlt,
find_homography_dlt_iterated,
find_homography_lines_dlt,
find_homography_lines_dlt_iterated,
symmetrical_epipolar_distance,
)
from kornia.geometry.homography import (
line_segment_transfer_error_one_way,
oneway_transfer_error,
sample_is_valid_for_homography,
)
from kornia.testing import KORNIA_CHECK_SHAPE

__all__ = ["RANSAC"]


class RANSAC(Module):
"""Module for robust geometry estimation with RANSAC. https://en.wikipedia.org/wiki/Random_sample_consensus.

    Args:
        model_type: type of model to estimate: ``'homography'``, ``'fundamental'``, or
            ``'homography_from_linesegments'``.
        inl_th: distance threshold for a correspondence to be counted as an inlier.
        batch_size: number of minimal samples drawn (and models estimated) per batch.
        max_iter: maximum number of batches to generate. The actual number of models tried
            is ``batch_size * max_iter``.
        confidence: desired confidence of the result, used for early stopping.
        max_lo_iters: maximum number of local optimization (polishing) iterations.
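
    Example:
        A minimal usage sketch with synthetic correspondences (shapes and values
        are illustrative only):

        >>> kp1 = torch.rand(100, 2)
        >>> kp2 = torch.rand(100, 2)
        >>> ransac = RANSAC(model_type='homography', inl_th=2.0)
        >>> model, inliers = ransac(kp1, kp2)  # model: (3, 3), inliers: (100,)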
"""
supported_models = ['homography', 'fundamental', 'homography_from_linesegments']

    def __init__(
self,
model_type: str = 'homography',
inl_th: float = 2.0,
batch_size: int = 2048,
max_iter: int = 10,
confidence: float = 0.99,
max_lo_iters: int = 5,
):
super().__init__()
self.inl_th = inl_th
self.max_iter = max_iter
self.batch_size = batch_size
        self.model_type = model_type
        self.confidence = confidence
        self.max_lo_iters = max_lo_iters
self.error_fn: Callable[..., Tensor]
self.minimal_solver: Callable[..., Tensor]
self.polisher_solver: Callable[..., Tensor]
if model_type == 'homography':
self.error_fn = oneway_transfer_error
self.minimal_solver = find_homography_dlt
self.polisher_solver = find_homography_dlt_iterated
self.minimal_sample_size = 4
elif model_type == 'homography_from_linesegments':
self.error_fn = line_segment_transfer_error_one_way
self.minimal_solver = find_homography_lines_dlt
self.polisher_solver = find_homography_lines_dlt_iterated
self.minimal_sample_size = 4
elif model_type == 'fundamental':
self.error_fn = symmetrical_epipolar_distance
self.minimal_solver = find_fundamental
self.minimal_sample_size = 8
# ToDo: implement 7pt solver instead of 8pt minimal_solver
# https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fundam.cpp#L498
self.polisher_solver = find_fundamental
else:
raise NotImplementedError(f"{model_type} is unknown. Try one of {self.supported_models}")

    def sample(self, sample_size: int, pop_size: int, batch_size: int, device: Device = torch.device('cpu')) -> Tensor:
        """Minimal sampler.

        Unlike traditional RANSAC, we draw many minimal subsets in a single batch
        to benefit from parallel processing, especially on GPU.
        """
rand = torch.rand(batch_size, pop_size, device=device)
_, out = rand.topk(k=sample_size, dim=1)
return out

    @staticmethod
    def max_samples_by_conf(n_inl: int, num_tc: int, sample_size: int, conf: float) -> float:
        """Compute the number of samples required to reach the desired confidence, enabling early stopping.

        See https://en.wikipedia.org/wiki/Random_sample_consensus.
        """
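        # Standard RANSAC stopping criterion: for inlier ratio w = n_inl / num_tc,
        # log(1 - conf) / log(1 - w^sample_size) samples are needed to draw at least
        # one all-inlier minimal sample with probability `conf`.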
if n_inl == num_tc:
return 1.0
return math.log(1.0 - conf) / math.log(1.0 - math.pow(n_inl / num_tc, sample_size))

    def estimate_model_from_minsample(self, kp1: Tensor, kp2: Tensor) -> Tensor:
batch_size, sample_size = kp1.shape[:2]
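        # Fit one model per minimal sample in the batch; all sampled correspondences get unit weight.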
H = self.minimal_solver(kp1, kp2, torch.ones(batch_size, sample_size, dtype=kp1.dtype, device=kp1.device))
return H

    def verify(self, kp1: Tensor, kp2: Tensor, models: Tensor, inl_th: float) -> Tuple[Tensor, Tensor, float]:
if len(kp1.shape) == 2:
kp1 = kp1[None]
if len(kp2.shape) == 2:
kp2 = kp2[None]
batch_size = models.shape[0]
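        # Score every candidate model against all correspondences at once by
        # expanding the keypoints over the model batch dimension.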
if self.model_type == 'homography_from_linesegments':
errors = self.error_fn(kp1.expand(batch_size, -1, 2, 2), kp2.expand(batch_size, -1, 2, 2), models)
else:
errors = self.error_fn(kp1.expand(batch_size, -1, 2), kp2.expand(batch_size, -1, 2), models)
inl = errors <= inl_th
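        # A model's score is its inlier count (the mask is cast to the keypoint dtype before summing).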
models_score = inl.to(kp1).sum(dim=1)
best_model_idx = models_score.argmax()
best_model_score = models_score[best_model_idx].item()
model_best = models[best_model_idx].clone()
inliers_best = inl[best_model_idx]
return model_best, inliers_best, best_model_score

    def remove_bad_samples(self, kp1: Tensor, kp2: Tensor) -> Tuple[Tensor, Tensor]:
        """Filter out minimal samples that are degenerate for the model being estimated."""
# ToDo: add (model-specific) verification of the samples,
# E.g. constraints on not to be a degenerate sample
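        # For homographies, keep only samples that pass kornia's orientation-based
        # degeneracy check (e.g. rejecting collinear or inconsistently ordered points).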
if self.model_type == 'homography':
mask = sample_is_valid_for_homography(kp1, kp2)
return kp1[mask], kp2[mask]
return kp1, kp2

    def remove_bad_models(self, models: Tensor) -> Tensor:
# ToDo: add more and better degenerate model rejection
# For now it is simple and hardcoded
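        # Reject models with a near-zero entry on the main diagonal: a cheap proxy
        # for degenerate or badly scaled 3x3 model matrices.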
main_diagonal = torch.diagonal(models, dim1=1, dim2=2)
mask = main_diagonal.abs().min(dim=1)[0] > 1e-4
return models[mask]

    def polish_model(self, kp1: Tensor, kp2: Tensor, inliers: Tensor) -> Tensor:
# TODO: Replace this with MAGSAC++ polisher
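        # Re-fit the model on all current inliers (with unit weights) using the
        # polishing solver for higher precision.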
kp1_inl = kp1[inliers][None]
kp2_inl = kp2[inliers][None]
num_inl = kp1_inl.size(1)
model = self.polisher_solver(
kp1_inl, kp2_inl, torch.ones(1, num_inl, dtype=kp1_inl.dtype, device=kp1_inl.device)
)
return model

    def validate_inputs(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> None:
        if self.model_type in ['homography', 'fundamental']:
            KORNIA_CHECK_SHAPE(kp1, ["N", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2"])
            if kp1.shape[0] != kp2.shape[0] or kp1.shape[0] < self.minimal_sample_size:
                raise ValueError(
                    "kp1 and kp2 should have the same length of at least "
                    f"{self.minimal_sample_size}, got {kp1.shape} and {kp2.shape}"
                )
        elif self.model_type == 'homography_from_linesegments':
            KORNIA_CHECK_SHAPE(kp1, ["N", "2", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2", "2"])
            if kp1.shape[0] != kp2.shape[0] or kp1.shape[0] < self.minimal_sample_size:
                raise ValueError(
                    "kp1 and kp2 should have the same shape of at least "
                    f"[{self.minimal_sample_size}, 2, 2], got {kp1.shape} and {kp2.shape}"
                )

    def forward(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        r"""Main forward method to execute the RANSAC algorithm.

        Args:
            kp1: source image keypoints :math:`(N, 2)`, or line segments
                :math:`(N, 2, 2)` for ``'homography_from_linesegments'``.
            kp2: destination image keypoints :math:`(N, 2)`, or line segments
                :math:`(N, 2, 2)` for ``'homography_from_linesegments'``.
            weights: optional correspondence weights. Not used at the moment.

        Returns:
            - Estimated model, shape of :math:`(3, 3)`.
            - The inlier/outlier mask, shape of :math:`(N,)`, where `N` is the number of input correspondences.
        """
self.validate_inputs(kp1, kp2, weights)
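        # Any model trivially fits its own minimal sample, so a candidate must
        # score strictly above `minimal_sample_size` inliers to be kept.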
best_score_total: float = float(self.minimal_sample_size)
num_tc: int = len(kp1)
best_model_total = zeros(3, 3, dtype=kp1.dtype, device=kp1.device)
        inliers_best_total: Tensor = zeros(num_tc, device=kp1.device, dtype=torch.bool)
for i in range(self.max_iter):
# Sample minimal samples in batch to estimate models
idxs = self.sample(self.minimal_sample_size, num_tc, self.batch_size, kp1.device)
kp1_sampled = kp1[idxs]
kp2_sampled = kp2[idxs]
kp1_sampled, kp2_sampled = self.remove_bad_samples(kp1_sampled, kp2_sampled)
if len(kp1_sampled) == 0:
continue
# Estimate models
models = self.estimate_model_from_minsample(kp1_sampled, kp2_sampled)
models = self.remove_bad_models(models)
if (models is None) or (len(models) == 0):
continue
# Score the models and select the best one
model, inliers, model_score = self.verify(kp1, kp2, models, self.inl_th)
            # Store the best-so-far model and (optionally) refine it by local optimization
if model_score > best_score_total:
# Local optimization
for lo_step in range(self.max_lo_iters):
model_lo = self.polish_model(kp1, kp2, inliers)
if (model_lo is None) or (len(model_lo) == 0):
continue
_, inliers_lo, score_lo = self.verify(kp1, kp2, model_lo, self.inl_th)
# print (f"Orig score = {best_model_score}, LO score = {score_lo} TC={num_tc}")
if score_lo > model_score:
model = model_lo.clone()[0]
inliers = inliers_lo.clone()
model_score = score_lo
else:
break
# Now storing the best model
best_model_total = model.clone()
inliers_best_total = inliers.clone()
best_score_total = model_score
# Should we already stop?
new_max_iter = int(
self.max_samples_by_conf(int(best_score_total), num_tc, self.minimal_sample_size, self.confidence)
)
# print (f"New max_iter = {new_max_iter}")
            # Stop early if enough samples have already been drawn for the desired confidence
if (i + 1) * self.batch_size >= new_max_iter:
break
        return best_model_total, inliers_best_total