# kornia.geometry.plane module inspired by Eigen::geometry::Hyperplane
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h
from typing import Optional
from kornia.core import Module, Tensor, stack, where
from kornia.core.tensor_wrapper import unwrap, wrap
from kornia.geometry.linalg import batched_dot_product
from kornia.geometry.vector import Scalar, Vector3
from kornia.testing import KORNIA_CHECK, KORNIA_CHECK_SHAPE, KORNIA_CHECK_TYPE
from kornia.utils.helpers import _torch_svd_cast
__all__ = ["Hyperplane", "fit_plane"]
def normalized(v: Tensor, eps: float = 1e-6) -> Tensor:
return v / batched_dot_product(v, v).add(eps).sqrt()
[docs]class Hyperplane(Module):
def __init__(self, n: Vector3, d: Scalar) -> None:
super().__init__()
KORNIA_CHECK_TYPE(n, Vector3)
KORNIA_CHECK_TYPE(d, Scalar)
# TODO: fix checkers
# KORNIA_CHECK_SHAPE(n, ["B", "*"])
# KORNIA_CHECK_SHAPE(d, ["B"])
self._n = n
self._d = d
def __str__(self) -> str:
return f"Normal: {self.normal}\nOffset: {self.offset}"
def __repr__(self) -> str:
return str(self)
@property
def normal(self) -> Vector3:
return self._n
@property
def offset(self) -> Scalar:
return self._d
def abs_distance(self, p: Vector3) -> Scalar:
return Scalar(self.signed_distance(p).abs())
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L145
# TODO: tests
def signed_distance(self, p: Vector3) -> Scalar:
KORNIA_CHECK(isinstance(p, (Vector3, Tensor)))
return self.normal.dot(p) + self.offset
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L154
# TODO: tests
def projection(self, p: Vector3) -> Vector3:
dist = self.signed_distance(p)
if len(dist.shape) != len(self.normal):
# non batched plane project a batch of points
dist = dist[..., None] # Nx1
# TODO: TypeError: bad operand type for unary -: 'Scalar'
return p - dist.data * self.normal
# TODO: make that Vector can subtract Scalar
# return p - self.signed_distance(p) * self.normal
@classmethod
def from_vector(self, n: Vector3, e: Vector3) -> "Hyperplane":
normal: Vector3 = n
offset = -normal.dot(e)
return Hyperplane(normal, Scalar(offset))
@classmethod
def through(cls, p0: Tensor, p1: Tensor, p2: Optional[Tensor] = None) -> "Hyperplane":
# 2d case
if p2 is None:
# TODO: improve tests
KORNIA_CHECK_SHAPE(p0, ["*", "2"])
KORNIA_CHECK(p0.shape == p1.shape)
# TODO: implement `.unitOrthonormal`
normal2d = normalized(p1 - p0)
offset2d = -batched_dot_product(p0, normal2d)
return Hyperplane(wrap(normal2d, Vector3), wrap(offset2d, Scalar))
# 3d case
KORNIA_CHECK_SHAPE(p0, ["*", "3"])
KORNIA_CHECK(p0.shape == p1.shape)
KORNIA_CHECK(p1.shape == p2.shape)
v0, v1 = (p2 - p0), (p1 - p0)
normal = v0.cross(v1)
norm = normal.norm(-1)
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L108
def compute_normal_svd(v0, v1):
# NOTE: for reason TensorWrapper does not stack well
m = stack((unwrap(v0), unwrap(v1)), -2) # Bx2x3
_, _, V = _torch_svd_cast(m) # kornia solution lies in the last row
return wrap(V[..., :, -1], Vector3) # Bx3
normal_mask = norm <= v0.norm(-1) * v1.norm(-1) * 1e-6
normal = where(normal_mask, compute_normal_svd(v0, v1), normal / (norm + 1e-6))
offset = -batched_dot_product(p0, normal)
return Hyperplane(wrap(normal, Vector3), wrap(offset, Scalar))
# TODO: factor to avoid duplicated from line.py
# https://github.com/strasdat/Sophus/blob/23.04-beta/cpp/sophus/geometry/fit_plane.h
def fit_plane(points: Vector3) -> Hyperplane:
"""Fit a plane from a set of points using SVD.
Args:
points: tensor containing a batch of sets of n-dimensional points. The expected
shape of the tensor is :math:`(N, D)`.
Return:
The computed hyperplane object.
"""
# TODO: fix to support more type check here
# KORNIA_CHECK_SHAPE(points, ["N", "D"])
if points.shape[-1] != 3:
raise TypeError("vector must be (*, 3)")
mean = points.mean(-2, True)
points_centered = points - mean
# NOTE: not optimal for 2d points, but for now works for other dimensions
_, _, V = _torch_svd_cast(points_centered)
# the first left eigenvector is the direction on the fited line
direction = V[..., :, -1] # BxD
origin = mean[..., 0, :] # BxD
return Hyperplane.from_vector(Vector3(direction), Vector3(origin))