from typing import List, Optional, Tuple, Union, cast
import torch
from kornia.core import Tensor
from kornia.geometry.bbox import validate_bbox
from kornia.geometry.linalg import transform_points
from kornia.utils import eye_like
__all__ = ["Boxes", "Boxes3D"]
def _is_floating_point_dtype(dtype: torch.dtype) -> bool:
return dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.half)
def _merge_box_list(boxes: List[torch.Tensor], method: str = "pad") -> Tuple[torch.Tensor, List[int]]:
r"""Merge a list of boxes into one tensor."""
if not all(box.shape[-2:] == torch.Size([4, 2]) and box.dim() == 3 for box in boxes):
raise TypeError(f"Input boxes must be a list of (N, 4, 2) shaped. Got: {[box.shape for box in boxes]}.")
if method == "pad":
max_N = max(box.shape[0] for box in boxes)
stats = [max_N - box.shape[0] for box in boxes]
output = torch.nn.utils.rnn.pad_sequence(boxes, batch_first=True)
else:
raise NotImplementedError(f"`{method}` is not implemented.")
return output, stats
def _transform_boxes(boxes: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
"""Transforms 3D and 2D in kornia format by applying the transformation matrix M. Boxes and the transformation
matrix could be batched or not.
Args:
boxes: 2D quadrilaterals or 3D hexahedrons in kornia format.
M: the transformation matrix of shape :math:`(3, 3)` or :math:`(B, 3, 3)` for 2D and :math:`(4, 4)` or
:math:`(B, 4, 4)` for 3D hexahedron.
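    Example:
        A minimal sketch with an identity transform, which leaves the boxes untouched:
        >>> boxes = torch.zeros(1, 1, 4, 2)
        >>> M = torch.eye(3)[None]
        >>> assert (_transform_boxes(boxes, M) == boxes).all()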
"""
M = M if M.is_floating_point() else M.float()
# Work with batch as kornia.transform_points only supports a batch of points.
boxes_per_batch, n_points_per_box, coordinates_dimension = boxes.shape[-3:]
points = boxes.view(-1, n_points_per_box * boxes_per_batch, coordinates_dimension)
M = M if M.ndim == 3 else M.unsqueeze(0)
if points.shape[0] != M.shape[0]:
raise ValueError(
f"Batch size mismatch. Got {points.shape[0]} for boxes and {M.shape[0]} for the transformation matrix."
)
transformed_boxes: torch.Tensor = transform_points(M, points)
transformed_boxes = transformed_boxes.view_as(boxes)
return transformed_boxes
def _boxes_to_polygons(
xmin: torch.Tensor, ymin: torch.Tensor, width: torch.Tensor, height: torch.Tensor
) -> torch.Tensor:
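    """Convert boxes given as per-coordinate ``xmin, ymin, width, height`` tensors of shape :math:`(B, N)` into
    :math:`(B, N, 4, 2)` vertices, following the ``width = xmax - xmin + 1`` convention."""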
if not xmin.ndim == ymin.ndim == width.ndim == height.ndim == 2:
raise ValueError("We expect to create a batch of 2D boxes (quadrilaterals) in vertices format (B, N, 4, 2)")
# Create (B,N,4,2) with all points in top left position of boxes
polygons = torch.zeros((xmin.shape[0], xmin.shape[1], 4, 2), device=xmin.device, dtype=xmin.dtype)
polygons[..., 0] = xmin.unsqueeze(-1)
polygons[..., 1] = ymin.unsqueeze(-1)
# Shift top-right, bottom-right, bottom-left points to the right coordinates
polygons[..., 1, 0] += width - 1 # Top right
polygons[..., 2, 0] += width - 1 # Bottom right
polygons[..., 2, 1] += height - 1 # Bottom right
polygons[..., 3, 1] += height - 1 # Bottom left
return polygons
def _boxes_to_quadrilaterals(boxes: torch.Tensor, mode: str = "xyxy", validate_boxes: bool = True) -> torch.Tensor:
"""Convert from boxes to quadrilaterals."""
mode = mode.lower()
if mode.startswith("vertices"):
batched = boxes.ndim == 4
if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == torch.Size([4, 2])):
raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2) when {mode} mode. Got {boxes.shape}.")
elif mode.startswith("xy"):
batched = boxes.ndim == 3
if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 4):
raise ValueError(f"Boxes shape must be (N, 4) or (B, N, 4) when {mode} mode. Got {boxes.shape}.")
else:
raise ValueError(f"Unknown mode {mode}")
boxes = boxes if boxes.is_floating_point() else boxes.float()
boxes = boxes if batched else boxes.unsqueeze(0)
if mode.startswith("vertices"):
if mode == "vertices":
quadrilaterals = boxes.clone()
# Here, vertices are quadrilaterals with width and height defined as `width = xmax - xmin` and
# `height = ymax - ymin`. We need to convert to `width = xmax - xmin + 1` and `height = ymax - ymin + 1` to
# match with internal Boxes Kornia representation.
quadrilaterals[..., 1:3, 0] = quadrilaterals[..., 1:3, 0] - 1
quadrilaterals[..., 2:, 1] = quadrilaterals[..., 2:, 1] - 1
elif mode == "vertices_plus":
# Avoid passing reference
quadrilaterals = boxes.clone()
else:
raise ValueError(f"Unknown mode {mode}")
        if validate_boxes:
            validate_bbox(quadrilaterals)
elif mode.startswith("xy"):
if mode == "xyxy":
height, width = boxes[..., 3] - boxes[..., 1], boxes[..., 2] - boxes[..., 0]
elif mode == "xyxy_plus":
height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
elif mode == "xywh":
height, width = boxes[..., 3], boxes[..., 2]
else:
raise ValueError(f"Unknown mode {mode}")
if validate_boxes:
            if (width <= 0).any():
                raise ValueError("Some boxes have zero or negative widths.")
            if (height <= 0).any():
                raise ValueError("Some boxes have zero or negative heights.")
xmin, ymin = boxes[..., 0], boxes[..., 1]
quadrilaterals = _boxes_to_polygons(xmin, ymin, width, height)
else:
raise ValueError(f"Unknown mode {mode}")
quadrilaterals = quadrilaterals if batched else quadrilaterals.squeeze(0)
return quadrilaterals
def _boxes3d_to_polygons3d(
xmin: torch.Tensor,
ymin: torch.Tensor,
zmin: torch.Tensor,
width: torch.Tensor,
height: torch.Tensor,
depth: torch.Tensor,
) -> torch.Tensor:
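    """Convert 3D boxes given as per-coordinate ``xmin, ymin, zmin, width, height, depth`` tensors of shape
    :math:`(B, N)` into :math:`(B, N, 8, 3)` hexahedron vertices (front face first, then back face)."""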
if not xmin.ndim == ymin.ndim == zmin.ndim == width.ndim == height.ndim == depth.ndim == 2:
raise ValueError("We expect to create a batch of 3D boxes (hexahedrons) in vertices format (B, N, 8, 3)")
# Front
# Create (B,N,4,3) with all points in front top left position of boxes
front_vertices = torch.zeros((xmin.shape[0], xmin.shape[1], 4, 3), device=xmin.device, dtype=xmin.dtype)
front_vertices[..., 0] = xmin.unsqueeze(-1)
front_vertices[..., 1] = ymin.unsqueeze(-1)
front_vertices[..., 2] = zmin.unsqueeze(-1)
# Shift front-top-right, front-bottom-right, front-bottom-left points to the right coordinates
front_vertices[..., 1, 0] += width - 1 # Top right
front_vertices[..., 2, 0] += width - 1 # Bottom right
front_vertices[..., 2, 1] += height - 1 # Bottom right
front_vertices[..., 3, 1] += height - 1 # Bottom left
# Back
back_vertices = front_vertices.clone()
back_vertices[..., 2] += depth.unsqueeze(-1) - 1
polygons3d = torch.cat([front_vertices, back_vertices], dim=-2)
return polygons3d
# NOTE: Cannot jit with Union types with torch <= 1.10
# @torch.jit.script
class Boxes:
r"""2D boxes containing N or BxN boxes.
Args:
boxes: 2D boxes, shape of :math:`(N, 4, 2)`, :math:`(B, N, 4, 2)` or a list of :math:`(N, 4, 2)`.
See below for more details.
raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a floating
point tensor. True to raise an error when `boxes` isn't a floating point tensor, False to cast to float.
mode: the box format of the input boxes.
Note:
**2D boxes format** is defined as a floating data type tensor of shape ``Nx4x2`` or ``BxNx4x2``
    where each box is a `quadrilateral <https://en.wikipedia.org/wiki/Quadrilateral>`_ defined by its 4 vertices
    coordinates (A, B, C, D). Coordinates must be in ``x, y`` order. The height and width of a box are defined as
``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. Examples of
    `quadrilaterals <https://en.wikipedia.org/wiki/Quadrilateral>`_ are rectangles, rhombuses and trapezoids.
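    Example:
        A minimal sketch constructing a single axis-aligned box directly from its 4 vertices:
        >>> boxes = Boxes(torch.tensor([[[0., 0.], [2., 0.], [2., 1.], [0., 1.]]]))  # Nx4x2 with N=1
        >>> boxes.data.shape
        torch.Size([1, 4, 2])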
"""
def __init__(
self,
boxes: Union[torch.Tensor, List[torch.Tensor]],
raise_if_not_floating_point: bool = True,
mode: str = "vertices_plus",
) -> None:
self._N: Optional[List[int]] = None
if isinstance(boxes, list):
boxes, self._N = _merge_box_list(boxes)
if not isinstance(boxes, torch.Tensor):
raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")
if not boxes.is_floating_point():
if raise_if_not_floating_point:
raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}")
boxes = boxes.float()
if len(boxes.shape) == 0:
# Use reshape, so we don't end up creating a new tensor that does not depend on
# the inputs (and consequently confuses jit)
boxes = boxes.reshape((-1, 4))
if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (4, 2)):
raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2). Got {boxes.shape}.")
self._is_batched = False if boxes.ndim == 3 else True
self._data = boxes
self._mode = mode
    def get_boxes_shape(self) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Compute boxes heights and widths.
Returns:
- Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
- Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.
Example:
>>> boxes_xyxy = torch.tensor([[[1,1,2,2],[1,1,3,2]]])
>>> boxes = Boxes.from_tensor(boxes_xyxy)
>>> boxes.get_boxes_shape()
(tensor([[1., 1.]]), tensor([[1., 2.]]))
"""
boxes_xywh = cast(torch.Tensor, self.to_tensor("xywh", as_padded_sequence=True))
widths, heights = boxes_xywh[..., 2], boxes_xywh[..., 3]
return heights, widths
    def merge(self, boxes: "Boxes", inplace: bool = False) -> "Boxes":
"""Merges boxes.
Say, current instance holds :math:`(B, N, 4, 2)` and the incoming boxes holds :math:`(B, M, 4, 2)`,
the merge results in :math:`(B, N + M, 4, 2)`.
Args:
boxes: 2D boxes.
inplace: do transform in-place and return self.
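        Example:
            A minimal sketch merging two batched sets holding one box each:
            >>> boxes_a = Boxes.from_tensor(torch.tensor([[[0., 0., 1., 1.]]]))
            >>> boxes_b = Boxes.from_tensor(torch.tensor([[[2., 2., 3., 3.]]]))
            >>> boxes_a.merge(boxes_b).data.shape
            torch.Size([1, 2, 4, 2])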
"""
data = torch.cat([self._data, boxes.data], dim=1)
if inplace:
self._data = data
return self
return Boxes(data, False)
def clamp(
self,
topleft: Optional[Union[Tensor, Tuple[int, int]]] = None,
botright: Optional[Union[Tensor, Tuple[int, int]]] = None,
inplace: bool = False,
) -> "Boxes":
""""""
if not (isinstance(topleft, Tensor) and isinstance(botright, Tensor)):
raise NotImplementedError
if inplace:
_data = self._data
else:
_data = self._data.clone()
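        # Broadcast the per-batch corner coordinates to every box and vertex, then clamp x and y independently.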
topleft_x = topleft[:, None, :1].repeat(1, _data.size(1), 4)
_data[..., 0][_data[..., 0] < topleft_x] = topleft_x[_data[..., 0] < topleft_x]
topleft_y = topleft[:, None, 1:].repeat(1, _data.size(1), 4)
_data[..., 1][_data[..., 1] < topleft_y] = topleft_y[_data[..., 1] < topleft_y]
botright_x = botright[:, None, :1].repeat(1, _data.size(1), 4)
_data[..., 0][_data[..., 0] > botright_x] = botright_x[_data[..., 0] > botright_x]
botright_y = botright[:, None, 1:].repeat(1, _data.size(1), 4)
_data[..., 1][_data[..., 1] > botright_y] = botright_y[_data[..., 1] > botright_y]
if inplace:
return self
return Boxes(_data, False)
    def trim(self, correspondence_preserve: bool = False, inplace: bool = False) -> "Boxes":
"""Trim out zero padded boxes.
Given box arrangements of shape :math:`(4, 4, Box)`:
-- Box -- Box -- Box -- Box --
-- 0 -- 0 -- Box -- Box --
-- 0 -- Box -- 0 -- 0 --
-- 0 -- 0 -- 0 -- 0 --
        If ``correspondence_preserve`` is True, the box-to-box correspondence is kept: only layers in which every
        entry is a zero padding box are removed, resulting in shape :math:`(4, 3, Box)`:
-- Box -- Box -- Box -- Box --
-- 0 -- 0 -- Box -- Box --
-- 0 -- Box -- 0 -- 0 --
Otherwise, you will get :math:`(4, 2, Box)`:
-- Box -- Box -- Box -- Box --
-- 0 -- Box -- Box -- Box --
"""
raise NotImplementedError
def filter_boxes_by_area(
self, min_area: Optional[float] = None, max_area: Optional[float] = None, inplace: bool = False
) -> "Boxes":
area = self.compute_area()
if inplace:
_data = self._data
else:
_data = self._data.clone()
if min_area is not None:
_data[area < min_area] = 0.0
if max_area is not None:
_data[area > max_area] = 0.0
if inplace:
return self
return Boxes(_data, False)
    def compute_area(self) -> torch.Tensor:
"""Returns :math:`(B, N)`."""
w = self._data[:, :, 1, 0] - self._data[:, :, 0, 0]
h = self._data[:, :, 2, 1] - self._data[:, :, 0, 1]
return w * h
    @classmethod
def from_tensor(
cls, boxes: Union[torch.Tensor, List[torch.Tensor]], mode: str = "xyxy", validate_boxes: bool = True
) -> "Boxes":
r"""Helper method to easily create :class:`Boxes` from boxes stored in another format.
Args:
boxes: 2D boxes, shape of :math:`(N, 4)`, :math:`(B, N, 4)`, :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
mode: The format in which the boxes are provided.
* 'xyxy': boxes are assumed to be in the format ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin``
and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and height are defined as
``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
* 'xywh': boxes are assumed to be in the format ``xmin, ymin, width, height`` where
``width = xmax - xmin`` and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
* 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
*top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
                * 'vertices_plus': similar to 'vertices' mode but where box width and height are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                  With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width
and height >= 1 (>= 2 when mode ends with '_plus' suffix).
Returns:
:class:`Boxes` class containing the original `boxes` in the format specified by ``mode``.
Examples:
>>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
>>> boxes = Boxes.from_tensor(boxes_xyxy, mode='xyxy')
>>> boxes.data # (2, 4, 2)
tensor([[[0., 3.],
[0., 3.],
[0., 3.],
[0., 3.]],
<BLANKLINE>
[[5., 1.],
[7., 1.],
[7., 3.],
[5., 3.]]])
"""
quadrilaterals: Union[torch.Tensor, List[torch.Tensor]]
if isinstance(boxes, (torch.Tensor)):
quadrilaterals = _boxes_to_quadrilaterals(boxes, mode=mode, validate_boxes=validate_boxes)
else:
quadrilaterals = [_boxes_to_quadrilaterals(box, mode, validate_boxes) for box in boxes]
# Due to some torch.jit.script bug (at least <= 1.9), you need to pass all arguments to __init__ when
# constructing the class from inside of a method.
return cls(quadrilaterals, False, mode)
    def to_tensor(
self, mode: Optional[str] = None, as_padded_sequence: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
r"""Cast :class:`Boxes` to a tensor. ``mode`` controls which 2D boxes format should be use to represent
boxes in the tensor.
Args:
mode: the output box format. It could be:
* 'xyxy': boxes are defined as ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin`` and
``height = ymax - ymin``.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and height are defined as
``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
* 'xywh': boxes are defined as ``xmin, ymin, width, height`` where ``width = xmax - xmin``
and ``height = ymax - ymin``.
* 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
*top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width and height are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
as_padded_sequence: whether to keep the pads for a list of boxes. This parameter is only valid
if the boxes are from a box list.
Returns:
            Boxes tensor in the ``mode`` format. The shape depends on the ``mode`` value:
                * 'vertices' or 'vertices_plus': :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
* Any other value: :math:`(N, 4)` or :math:`(B, N, 4)`.
Examples:
>>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
>>> boxes = Boxes.from_tensor(boxes_xyxy)
>>> assert (boxes_xyxy == boxes.to_tensor(mode='xyxy')).all()
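            When the boxes were built from a list, the internal padding is removed and a list of tensors is
            returned (a minimal sketch):
            >>> boxes_list = Boxes([torch.zeros(1, 4, 2), torch.zeros(2, 4, 2)])
            >>> [t.shape for t in boxes_list.to_tensor(mode='xyxy')]
            [torch.Size([1, 4]), torch.Size([2, 4])]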
"""
batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)
boxes: Union[torch.Tensor, List[torch.Tensor]]
# Create boxes in xyxy_plus format.
boxes = torch.stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
batched_boxes.shape[0], batched_boxes.shape[1], 4
)
if mode is None:
mode = self.mode
mode = mode.lower()
if mode in ("xyxy", "xyxy_plus"):
pass
elif mode in ("xywh", "vertices", "vertices_plus"):
height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
boxes[..., 2] = width
boxes[..., 3] = height
else:
raise ValueError(f"Unknown mode {mode}")
if mode in ("xyxy", "vertices"):
offset = torch.as_tensor([0, 0, 1, 1], device=boxes.device, dtype=boxes.dtype)
boxes = boxes + offset
if mode.startswith('vertices'):
boxes = _boxes_to_polygons(boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3])
if self._N is not None and not as_padded_sequence:
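            # Drop the rows added as padding by `_merge_box_list`: a negative pad size crops `n` boxes from the end.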
boxes = list(
torch.nn.functional.pad(o, (len(o.shape) - 1) * [0, 0] + [0, -n]) for o, n in zip(boxes, self._N)
)
else:
boxes = boxes if self._is_batched else boxes.squeeze(0)
return boxes
    def to_mask(self, height: int, width: int) -> torch.Tensor:
"""Convert 2D boxes to masks. Covered area is 1 and the remaining is 0.
Args:
height: height of the masked image/images.
width: width of the masked image/images.
Returns:
            the output mask tensor, shape of :math:`(N, height, width)` or :math:`(B, N, height, width)` and dtype of
:func:`Boxes.dtype` (it can be any floating point dtype).
Note:
It is currently non-differentiable.
Examples:
>>> boxes = Boxes(torch.tensor([[ # Equivalent to boxes = Boxes.from_tensor([[1,1,4,3]])
... [1., 1.],
... [4., 1.],
... [4., 3.],
... [1., 3.],
... ]])) # 1x4x2
>>> boxes.to_mask(5, 5)
tensor([[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.]]])
"""
if self._data.requires_grad:
raise RuntimeError(
"Boxes.to_tensor isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
)
if self._is_batched: # (B, N, 4, 2)
mask = torch.zeros(
(self._data.shape[0], self._data.shape[1], height, width), dtype=self.dtype, device=self.device
)
else: # (N, 4, 2)
mask = torch.zeros((self._data.shape[0], height, width), dtype=self.dtype, device=self.device)
# Boxes coordinates can be outside the image size after transforms. Clamp values to the image size
clipped_boxes_xyxy = cast(torch.Tensor, self.to_tensor("xyxy", as_padded_sequence=True))
clipped_boxes_xyxy[..., ::2].clamp_(0, width)
clipped_boxes_xyxy[..., 1::2].clamp_(0, height)
# Reshape mask to (BxN, H, W) and boxes to (BxN, 4) to iterate over all of them.
# Cast boxes coordinates to be integer to use them as indexes. Use round to handle decimal values.
for mask_channel, box_xyxy in zip(mask.view(-1, height, width), clipped_boxes_xyxy.view(-1, 4).round().int()):
# Mask channel dimensions: (height, width)
mask_channel[box_xyxy[1] : box_xyxy[3], box_xyxy[0] : box_xyxy[2]] = 1
return mask
    def translate(self, size: Tensor, method: str = "warp", inplace: bool = False) -> "Boxes":
"""Translates boxes by the provided size.
Args:
size: translate size for x, y direction, shape of :math:`(B, 2)`.
method: "warp" or "fast".
inplace: do transform in-place and return self.
Returns:
The transformed boxes.
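        Example:
            A minimal sketch; the translation matrix shifts every stored vertex by the given x, y offsets:
            >>> boxes = Boxes.from_tensor(torch.tensor([[[0., 0., 1., 1.]]]))
            >>> translated = boxes.translate(torch.tensor([[1., 2.]]))
            >>> assert (translated.data == boxes.data + torch.tensor([1., 2.])).all()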
"""
if method == "fast":
raise NotImplementedError
elif method == "warp":
pass
else:
raise NotImplementedError
M: Tensor = eye_like(3, size)
M[:, :2, 2] = size
return self.transform_boxes(M, inplace=inplace)
@property
def data(self) -> torch.Tensor:
return self._data
@property
def mode(self) -> str:
return self._mode
@property
def device(self) -> torch.device:
"""Returns boxes device."""
return self._data.device
@property
def dtype(self) -> torch.dtype:
"""Returns boxes dtype."""
return self._data.dtype
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "Boxes":
"""Like :func:`torch.nn.Module.to()` method."""
# In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
if dtype is not None and not _is_floating_point_dtype(dtype):
raise ValueError("Boxes must be in floating point")
self._data = self._data.to(device=device, dtype=dtype)
return self
def clone(self) -> "Boxes":
return Boxes(self._data.clone(), False)
@torch.jit.script
class Boxes3D:
r"""3D boxes containing N or BxN boxes.
Args:
boxes: 3D boxes, shape of :math:`(N,8,3)` or :math:`(B,N,8,3)`. See below for more details.
raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a floating
point tensor. True to raise an error when `boxes` isn't a floating point tensor, False to cast to float.
Note:
**3D boxes format** is defined as a floating data type tensor of shape ``Nx8x3`` or ``BxNx8x3`` where each box
    is a `hexahedron <https://en.wikipedia.org/wiki/Hexahedron>`_ defined by its 8 vertices coordinates.
    Coordinates must be in ``x, y, z`` order. The height, width and depth of a box are defined as
``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``. Examples of
`hexahedrons <https://en.wikipedia.org/wiki/Hexahedron>`_ are cubes and rhombohedrons.
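    Example:
        A minimal sketch, building one box from an ``xmin, ymin, zmin, xmax, ymax, zmax`` row:
        >>> boxes = Boxes3D.from_tensor(torch.tensor([[0., 0., 0., 2., 2., 2.]]), mode='xyzxyz')
        >>> boxes.data.shape
        torch.Size([1, 8, 3])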
"""
def __init__(
self, boxes: torch.Tensor, raise_if_not_floating_point: bool = True, mode: str = "xyzxyz_plus"
) -> None:
if not isinstance(boxes, torch.Tensor):
raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")
if not boxes.is_floating_point():
if raise_if_not_floating_point:
raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}.")
boxes = boxes.float()
if len(boxes.shape) == 0:
# Use reshape, so we don't end up creating a new tensor that does not depend on
# the inputs (and consequently confuses jit)
boxes = boxes.reshape((-1, 6))
if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (8, 3)):
raise ValueError(f"3D bbox shape must be (N, 8, 3) or (B, N, 8, 3). Got {boxes.shape}.")
self._is_batched = False if boxes.ndim == 3 else True
self._data = boxes
self._mode = mode
    def get_boxes_shape(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Compute boxes heights and widths.
Returns:
- Boxes depths, shape of :math:`(N,)` or :math:`(B,N)`.
- Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
- Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.
Example:
>>> boxes_xyzxyz = torch.tensor([[ 0, 1, 2, 10, 21, 32], [3, 4, 5, 43, 54, 65]])
>>> boxes3d = Boxes3D.from_tensor(boxes_xyzxyz)
>>> boxes3d.get_boxes_shape()
(tensor([30., 60.]), tensor([20., 50.]), tensor([10., 40.]))
"""
boxes_xyzwhd = self.to_tensor(mode='xyzwhd')
widths, heights, depths = boxes_xyzwhd[..., 3], boxes_xyzwhd[..., 4], boxes_xyzwhd[..., 5]
return depths, heights, widths
    @classmethod
def from_tensor(cls, boxes: torch.Tensor, mode: str = "xyzxyz", validate_boxes: bool = True) -> "Boxes3D":
r"""Helper method to easily create :class:`Boxes3D` from 3D boxes stored in another format.
Args:
boxes: 3D boxes, shape of :math:`(N,6)` or :math:`(B,N,6)`.
mode: The format in which the 3D boxes are provided.
* 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, height and depth are defined as
``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
* 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width, height
and depth >= 1 (>= 2 when mode ends with '_plus' suffix).
Returns:
:class:`Boxes3D` class containing the original `boxes` in the format specified by ``mode``.
Examples:
>>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
>>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
>>> boxes.data # (2, 8, 3)
tensor([[[0., 3., 6.],
[0., 3., 6.],
[0., 3., 6.],
[0., 3., 6.],
[0., 3., 7.],
[0., 3., 7.],
[0., 3., 7.],
[0., 3., 7.]],
<BLANKLINE>
[[5., 1., 3.],
[7., 1., 3.],
[7., 3., 3.],
[5., 3., 3.],
[5., 1., 8.],
[7., 1., 8.],
[7., 3., 8.],
[5., 3., 8.]]])
"""
if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 6):
raise ValueError(f"BBox shape must be (N, 6) or (B, N, 6). Got {boxes.shape}.")
batched = boxes.ndim == 3
boxes = boxes if batched else boxes.unsqueeze(0)
boxes = boxes if boxes.is_floating_point() else boxes.float()
xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
mode = mode.lower()
if mode == "xyzxyz":
width = boxes[..., 3] - boxes[..., 0]
height = boxes[..., 4] - boxes[..., 1]
depth = boxes[..., 5] - boxes[..., 2]
elif mode == "xyzxyz_plus":
width = boxes[..., 3] - boxes[..., 0] + 1
height = boxes[..., 4] - boxes[..., 1] + 1
depth = boxes[..., 5] - boxes[..., 2] + 1
elif mode == "xyzwhd":
            width, height, depth = boxes[..., 3], boxes[..., 4], boxes[..., 5]
else:
raise ValueError(f"Unknown mode {mode}")
if validate_boxes:
            if (width <= 0).any():
                raise ValueError("Some boxes have zero or negative widths.")
            if (height <= 0).any():
                raise ValueError("Some boxes have zero or negative heights.")
            if (depth <= 0).any():
                raise ValueError("Some boxes have zero or negative depths.")
hexahedrons = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)
hexahedrons = hexahedrons if batched else hexahedrons.squeeze(0)
# Due to some torch.jit.script bug (at least <= 1.9), you need to pass all arguments to __init__ when
# constructing the class from inside of a method.
return cls(hexahedrons, raise_if_not_floating_point=False, mode=mode)
    def to_tensor(self, mode: str = "xyzxyz") -> torch.Tensor:
r"""Cast :class:`Boxes3D` to a tensor. ``mode`` controls which 3D boxes format should be use to represent
boxes in the tensor.
Args:
mode: The format in which the boxes are provided.
* 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, height and depth are defined as
``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
* 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
* 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
*front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
back-top-right, back-bottom-right, back-bottom-left*. Vertices coordinates are in (x,y, z) order.
Finally, box width, height and depth are defined as ``width = xmax - xmin``, ``height = ymax - ymin``
and ``depth = zmax - zmin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width, height and depth are defined as
                  ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
Returns:
            3D Boxes tensor in the ``mode`` format. The shape depends on the ``mode`` value:
                * 'vertices' or 'vertices_plus': :math:`(N, 8, 3)` or :math:`(B, N, 8, 3)`.
* Any other value: :math:`(N, 6)` or :math:`(B, N, 6)`.
Note:
It is currently non-differentiable due to a bug. See github issue
            `#1396 <https://github.com/kornia/kornia/issues/1396>`_.
Examples:
>>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
>>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
>>> assert (boxes.to_tensor(mode='xyzxyz') == boxes_xyzxyz).all()
"""
if self._data.requires_grad:
raise RuntimeError(
"Boxes3D.to_tensor doesn't support computing gradients since they aren't accurate. "
"Please, create boxes from tensors with `requires_grad=False`. "
"This is a known bug. Help is needed to fix it. For more information, "
"see https://github.com/kornia/kornia/issues/1396."
)
batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)
# Create boxes in xyzxyz_plus format.
boxes = torch.stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
batched_boxes.shape[0], batched_boxes.shape[1], 6
)
mode = mode.lower()
if mode in ("xyzxyz", "xyzxyz_plus"):
pass
elif mode in ("xyzwhd", "vertices", "vertices_plus"):
width = boxes[..., 3] - boxes[..., 0] + 1
height = boxes[..., 4] - boxes[..., 1] + 1
depth = boxes[..., 5] - boxes[..., 2] + 1
boxes[..., 3] = width
boxes[..., 4] = height
boxes[..., 5] = depth
else:
raise ValueError(f"Unknown mode {mode}")
if mode in ("xyzxyz", "vertices"):
offset = torch.as_tensor([0, 0, 0, 1, 1, 1], device=boxes.device, dtype=boxes.dtype)
boxes = boxes + offset
if mode.startswith('vertices'):
xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
width, height, depth = boxes[..., 3], boxes[..., 4], boxes[..., 5]
boxes = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)
boxes = boxes if self._is_batched else boxes.squeeze(0)
return boxes
    def to_mask(self, depth: int, height: int, width: int) -> torch.Tensor:
"""Convert ·D boxes to masks. Covered area is 1 and the remaining is 0.
Args:
depth: depth of the masked image/images.
height: height of the masked image/images.
width: width of the masked image/images.
Returns:
            the output mask tensor, shape of :math:`(N, depth, height, width)` or :math:`(B, N, depth, height, width)`
and dtype of :func:`Boxes3D.dtype` (it can be any floating point dtype).
Note:
It is currently non-differentiable.
Examples:
            >>> boxes = Boxes3D(torch.tensor([[ # Equivalent to boxes = Boxes3D.from_tensor([[1,1,1,3,3,2]])
... [1., 1., 1.],
... [3., 1., 1.],
... [3., 3., 1.],
... [1., 3., 1.],
... [1., 1., 2.],
... [3., 1., 2.],
... [3., 3., 2.],
... [1., 3., 2.],
... ]])) # 1x8x3
>>> boxes.to_mask(4, 5, 5)
tensor([[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 0.],
[0., 1., 1., 1., 0.],
[0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 0.],
[0., 1., 1., 1., 0.],
[0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]])
"""
if self._data.requires_grad:
raise RuntimeError(
"Boxes.to_tensor isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
)
if self._is_batched: # (B, N, 8, 3)
mask = torch.zeros(
(self._data.shape[0], self._data.shape[1], depth, height, width),
dtype=self._data.dtype,
device=self._data.device,
)
else: # (N, 8, 3)
mask = torch.zeros(
(self._data.shape[0], depth, height, width), dtype=self._data.dtype, device=self._data.device
)
# Boxes coordinates can be outside the image size after transforms. Clamp values to the image size
clipped_boxes_xyzxyz = self.to_tensor("xyzxyz")
clipped_boxes_xyzxyz[..., ::3].clamp_(0, width)
clipped_boxes_xyzxyz[..., 1::3].clamp_(0, height)
clipped_boxes_xyzxyz[..., 2::3].clamp_(0, depth)
# Reshape mask to (BxN, D, H, W) and boxes to (BxN, 6) to iterate over all of them.
# Cast boxes coordinates to be integer to use them as indexes. Use round to handle decimal values.
for mask_channel, box_xyzxyz in zip(
mask.view(-1, depth, height, width), clipped_boxes_xyzxyz.view(-1, 6).round().int()
):
# Mask channel dimensions: (depth, height, width)
mask_channel[
box_xyzxyz[2] : box_xyzxyz[5], box_xyzxyz[1] : box_xyzxyz[4], box_xyzxyz[0] : box_xyzxyz[3]
] = 1
return mask
@property
def data(self) -> torch.Tensor:
return self._data
@property
def mode(self) -> str:
return self._mode
@property
def device(self) -> torch.device:
"""Returns boxes device."""
return self._data.device
@property
def dtype(self) -> torch.dtype:
"""Returns boxes dtype."""
return self._data.dtype
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "Boxes3D":
"""Like :func:`torch.nn.Module.to()` method."""
# In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
if dtype is not None and not _is_floating_point_dtype(dtype):
raise ValueError("Boxes must be in floating point")
self._data = self._data.to(device=device, dtype=dtype)
return self