import warnings
from typing import Optional, Tuple
import torch
from .linalg import transform_points
__all__ = [
"validate_bbox",
"validate_bbox3d",
"infer_bbox_shape",
"infer_bbox_shape3d",
"bbox_to_mask",
"bbox_to_mask3d",
"bbox_generator",
"bbox_generator3d",
"transform_bbox",
"nms",
]
[docs]@torch.jit.ignore
def validate_bbox(boxes: torch.Tensor) -> bool:
"""Validate if a 2D bounding box usable or not. This function checks if the boxes are rectangular or not.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
bottom-left. The coordinates must be in the x, y order.
"""
if not (len(boxes.shape) in [3, 4] and boxes.shape[-2:] == torch.Size([4, 2])):
raise AssertionError(f"Box shape must be (B, 4, 2) or (B, N, 4, 2). Got {boxes.shape}.")
if len(boxes.shape) == 4:
boxes = boxes.view(-1, 4, 2)
if not torch.allclose((boxes[:, 1, 0] - boxes[:, 0, 0] + 1), (boxes[:, 2, 0] - boxes[:, 3, 0] + 1), atol=1e-4):
raise ValueError(
"Boxes must have be rectangular, while get widths %s and %s"
% (str(boxes[:, 1, 0] - boxes[:, 0, 0] + 1), str(boxes[:, 2, 0] - boxes[:, 3, 0] + 1))
)
if not torch.allclose((boxes[:, 2, 1] - boxes[:, 0, 1] + 1), (boxes[:, 3, 1] - boxes[:, 1, 1] + 1), atol=1e-4):
raise ValueError(
"Boxes must have be rectangular, while get heights %s and %s"
% (str(boxes[:, 2, 1] - boxes[:, 0, 1] + 1), str(boxes[:, 3, 1] - boxes[:, 1, 1] + 1))
)
return True
[docs]@torch.jit.ignore
def validate_bbox3d(boxes: torch.Tensor) -> bool:
"""Validate if a 3D bounding box usable or not. This function checks if the boxes are cube or not.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
The coordinates must be in the x, y, z order.
"""
if not (len(boxes.shape) in [3, 4] and boxes.shape[-2:] == torch.Size([8, 3])):
raise AssertionError(f"Box shape must be (B, 8, 3) or (B, N, 8, 3). Got {boxes.shape}.")
if len(boxes.shape) == 4:
boxes = boxes.view(-1, 8, 3)
left = torch.index_select(boxes, 1, torch.tensor([1, 2, 5, 6], device=boxes.device, dtype=torch.long))[:, :, 0]
right = torch.index_select(boxes, 1, torch.tensor([0, 3, 4, 7], device=boxes.device, dtype=torch.long))[:, :, 0]
widths = left - right + 1
if not torch.allclose(widths.permute(1, 0), widths[:, 0]):
raise AssertionError(f"Boxes must have be cube, while get different widths {widths}.")
bot = torch.index_select(boxes, 1, torch.tensor([2, 3, 6, 7], device=boxes.device, dtype=torch.long))[:, :, 1]
upper = torch.index_select(boxes, 1, torch.tensor([0, 1, 4, 5], device=boxes.device, dtype=torch.long))[:, :, 1]
heights = bot - upper + 1
if not torch.allclose(heights.permute(1, 0), heights[:, 0]):
raise AssertionError(f"Boxes must have be cube, while get different heights {heights}.")
depths = boxes[:, 4:, 2] - boxes[:, :4, 2] + 1
if not torch.allclose(depths.permute(1, 0), depths[:, 0]):
raise AssertionError(f"Boxes must have be cube, while get different depths {depths}.")
return True
[docs]def infer_bbox_shape(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Auto-infer the output sizes for the given 2D bounding boxes.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
bottom-left. The coordinates must be in the x, y order.
Returns:
- Bounding box heights, shape of :math:`(B,)`.
- Boundingbox widths, shape of :math:`(B,)`.
Example:
>>> boxes = torch.tensor([[
... [1., 1.],
... [2., 1.],
... [2., 2.],
... [1., 2.],
... ], [
... [1., 1.],
... [3., 1.],
... [3., 2.],
... [1., 2.],
... ]]) # 2x4x2
>>> infer_bbox_shape(boxes)
(tensor([2., 2.]), tensor([2., 3.]))
"""
validate_bbox(boxes)
width: torch.Tensor = boxes[:, 1, 0] - boxes[:, 0, 0] + 1
height: torch.Tensor = boxes[:, 2, 1] - boxes[:, 0, 1] + 1
return height, width
[docs]def infer_bbox_shape3d(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Auto-infer the output sizes for the given 3D bounding boxes.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
The coordinates must be in the x, y, z order.
Returns:
- Bounding box depths, shape of :math:`(B,)`.
- Bounding box heights, shape of :math:`(B,)`.
- Bounding box widths, shape of :math:`(B,)`.
Example:
>>> boxes = torch.tensor([[[ 0, 1, 2],
... [10, 1, 2],
... [10, 21, 2],
... [ 0, 21, 2],
... [ 0, 1, 32],
... [10, 1, 32],
... [10, 21, 32],
... [ 0, 21, 32]],
... [[ 3, 4, 5],
... [43, 4, 5],
... [43, 54, 5],
... [ 3, 54, 5],
... [ 3, 4, 65],
... [43, 4, 65],
... [43, 54, 65],
... [ 3, 54, 65]]]) # 2x8x3
>>> infer_bbox_shape3d(boxes)
(tensor([31, 61]), tensor([21, 51]), tensor([11, 41]))
"""
validate_bbox3d(boxes)
left = torch.index_select(boxes, 1, torch.tensor([1, 2, 5, 6], device=boxes.device, dtype=torch.long))[:, :, 0]
right = torch.index_select(boxes, 1, torch.tensor([0, 3, 4, 7], device=boxes.device, dtype=torch.long))[:, :, 0]
widths = (left - right + 1)[:, 0]
bot = torch.index_select(boxes, 1, torch.tensor([2, 3, 6, 7], device=boxes.device, dtype=torch.long))[:, :, 1]
upper = torch.index_select(boxes, 1, torch.tensor([0, 1, 4, 5], device=boxes.device, dtype=torch.long))[:, :, 1]
heights = (bot - upper + 1)[:, 0]
depths = (boxes[:, 4:, 2] - boxes[:, :4, 2] + 1)[:, 0]
return depths, heights, widths
[docs]def bbox_to_mask(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
"""Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right
and bottom-left. The coordinates must be in the x, y order.
width: width of the masked image.
height: height of the masked image.
Returns:
the output mask tensor.
Note:
It is currently non-differentiable.
Examples:
>>> boxes = torch.tensor([[
... [1., 1.],
... [3., 1.],
... [3., 2.],
... [1., 2.],
... ]]) # 1x4x2
>>> bbox_to_mask(boxes, 5, 5)
tensor([[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 0.],
[0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
"""
validate_bbox(boxes)
# zero padding the surroudings
mask = torch.zeros((len(boxes), height + 2, width + 2), dtype=torch.float, device=boxes.device)
# push all points one pixel off
# in order to zero-out the fully filled rows or columns
box_i = (boxes + 1).long()
# set all pixels within box to 1
for msk, bx in zip(mask, box_i):
msk[bx[0, 1] : bx[2, 1] + 1, bx[0, 0] : bx[1, 0] + 1] = 1.0
return mask[:, 1:-1, 1:-1]
[docs]def bbox_to_mask3d(boxes: torch.Tensor, size: Tuple[int, int, int]) -> torch.Tensor:
"""Convert 3D bounding boxes to masks. Covered area is 1. and the remaining is 0.
Args:
boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
The coordinates must be in the x, y, z order.
size: depth, height and width of the masked image.
Returns:
the output mask tensor.
Examples:
>>> boxes = torch.tensor([[
... [1., 1., 1.],
... [2., 1., 1.],
... [2., 2., 1.],
... [1., 2., 1.],
... [1., 1., 2.],
... [2., 1., 2.],
... [2., 2., 2.],
... [1., 2., 2.],
... ]]) # 1x8x3
>>> bbox_to_mask3d(boxes, (4, 5, 5))
tensor([[[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]]])
"""
validate_bbox3d(boxes)
mask = torch.zeros((len(boxes), *size))
mask_out = []
# TODO: Looking for a vectorized way
for m, box in zip(mask, boxes):
m = m.index_fill(
0,
torch.arange(box[0, 2].item(), box[4, 2].item() + 1, device=box.device, dtype=torch.long),
torch.tensor(1, device=box.device, dtype=box.dtype),
)
m = m.index_fill(
1,
torch.arange(box[1, 1].item(), box[2, 1].item() + 1, device=box.device, dtype=torch.long),
torch.tensor(1, device=box.device, dtype=box.dtype),
)
m = m.index_fill(
2,
torch.arange(box[0, 0].item(), box[1, 0].item() + 1, device=box.device, dtype=torch.long),
torch.tensor(1, device=box.device, dtype=box.dtype),
)
m = m.unsqueeze(dim=0)
m_out = torch.ones_like(m)
m_out = m_out * (m == 1).all(dim=2, keepdim=True).all(dim=1, keepdim=True)
m_out = m_out * (m == 1).all(dim=3, keepdim=True).all(dim=1, keepdim=True)
m_out = m_out * (m == 1).all(dim=2, keepdim=True).all(dim=3, keepdim=True)
mask_out.append(m_out)
return torch.stack(mask_out, dim=0).float()
[docs]def bbox_generator(
x_start: torch.Tensor, y_start: torch.Tensor, width: torch.Tensor, height: torch.Tensor
) -> torch.Tensor:
"""Generate 2D bounding boxes according to the provided start coords, width and height.
Args:
x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
tensor or :math:`(B,)`.
y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
tensor or :math:`(B,)`.
width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
Returns:
the bounding box tensor.
Examples:
>>> x_start = torch.tensor([0, 1])
>>> y_start = torch.tensor([1, 0])
>>> width = torch.tensor([5, 3])
>>> height = torch.tensor([7, 4])
>>> bbox_generator(x_start, y_start, width, height)
tensor([[[0, 1],
[4, 1],
[4, 7],
[0, 7]],
<BLANKLINE>
[[1, 0],
[3, 0],
[3, 3],
[1, 3]]])
"""
if not (x_start.shape == y_start.shape and x_start.dim() in [0, 1]):
raise AssertionError(f"`x_start` and `y_start` must be a scalar or (B,). Got {x_start}, {y_start}.")
if not (width.shape == height.shape and width.dim() in [0, 1]):
raise AssertionError(f"`width` and `height` must be a scalar or (B,). Got {width}, {height}.")
if not x_start.dtype == y_start.dtype == width.dtype == height.dtype:
raise AssertionError(
"All tensors must be in the same dtype. Got "
f"`x_start`({x_start.dtype}), `y_start`({x_start.dtype}), `width`({width.dtype}), `height`({height.dtype})."
)
if not x_start.device == y_start.device == width.device == height.device:
raise AssertionError(
"All tensors must be in the same device. Got "
f"`x_start`({x_start.device}), `y_start`({x_start.device}), "
f"`width`({width.device}), `height`({height.device})."
)
bbox = torch.tensor([[[0, 0], [0, 0], [0, 0], [0, 0]]], device=x_start.device, dtype=x_start.dtype).repeat(
1 if x_start.dim() == 0 else len(x_start), 1, 1
)
bbox[:, :, 0] += x_start.view(-1, 1)
bbox[:, :, 1] += y_start.view(-1, 1)
bbox[:, 1, 0] += width - 1
bbox[:, 2, 0] += width - 1
bbox[:, 2, 1] += height - 1
bbox[:, 3, 1] += height - 1
return bbox
[docs]def bbox_generator3d(
x_start: torch.Tensor,
y_start: torch.Tensor,
z_start: torch.Tensor,
width: torch.Tensor,
height: torch.Tensor,
depth: torch.Tensor,
) -> torch.Tensor:
"""Generate 3D bounding boxes according to the provided start coords, width, height and depth.
Args:
x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
tensor or :math:`(B,)`.
y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
tensor or :math:`(B,)`.
z_start: a tensor containing the z coordinates of the bounding boxes to be extracted. Shape must be a scalar
tensor or :math:`(B,)`.
width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
depth: depths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
Returns:
the 3d bounding box tensor :math:`(B, 8, 3)`.
Examples:
>>> x_start = torch.tensor([0, 3])
>>> y_start = torch.tensor([1, 4])
>>> z_start = torch.tensor([2, 5])
>>> width = torch.tensor([10, 40])
>>> height = torch.tensor([20, 50])
>>> depth = torch.tensor([30, 60])
>>> bbox_generator3d(x_start, y_start, z_start, width, height, depth)
tensor([[[ 0, 1, 2],
[10, 1, 2],
[10, 21, 2],
[ 0, 21, 2],
[ 0, 1, 32],
[10, 1, 32],
[10, 21, 32],
[ 0, 21, 32]],
<BLANKLINE>
[[ 3, 4, 5],
[43, 4, 5],
[43, 54, 5],
[ 3, 54, 5],
[ 3, 4, 65],
[43, 4, 65],
[43, 54, 65],
[ 3, 54, 65]]])
"""
if not (x_start.shape == y_start.shape == z_start.shape and x_start.dim() in [0, 1]):
raise AssertionError(
f"`x_start`, `y_start` and `z_start` must be a scalar or (B,). Got {x_start}, {y_start}, {z_start}."
)
if not (width.shape == height.shape == depth.shape and width.dim() in [0, 1]):
raise AssertionError(f"`width`, `height` and `depth` must be a scalar or (B,). Got {width}, {height}, {depth}.")
if not x_start.dtype == y_start.dtype == z_start.dtype == width.dtype == height.dtype == depth.dtype:
raise AssertionError(
"All tensors must be in the same dtype. "
f"Got `x_start`({x_start.dtype}), `y_start`({x_start.dtype}), `z_start`({x_start.dtype}), "
f"`width`({width.dtype}), `height`({height.dtype}) and `depth`({depth.dtype})."
)
if not x_start.device == y_start.device == z_start.device == width.device == height.device == depth.device:
raise AssertionError(
"All tensors must be in the same device. "
f"Got `x_start`({x_start.device}), `y_start`({x_start.device}), `z_start`({x_start.device}), "
f"`width`({width.device}), `height`({height.device}) and `depth`({depth.device})."
)
# front
bbox = torch.tensor(
[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]], device=x_start.device, dtype=x_start.dtype
).repeat(len(x_start), 1, 1)
bbox[:, :, 0] += x_start.view(-1, 1)
bbox[:, :, 1] += y_start.view(-1, 1)
bbox[:, :, 2] += z_start.view(-1, 1)
bbox[:, 1, 0] += width
bbox[:, 2, 0] += width
bbox[:, 2, 1] += height
bbox[:, 3, 1] += height
# back
bbox_back = bbox.clone()
bbox_back[:, :, -1] += depth.unsqueeze(dim=1).repeat(1, 4)
bbox = torch.cat([bbox, bbox_back], dim=1)
return bbox
[docs]def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
"""Perform non-maxima suppression (NMS) on a given tensor of bounding boxes according to the intersection-over-
union (IoU).
Args:
boxes: tensor containing the encoded bounding boxes with the shape :math:`(N, (x_1, y_1, x_2, y_2))`.
scores: tensor containing the scores associated to each bounding box with shape :math:`(N,)`.
iou_threshold: the throshold to discard the overlapping boxes.
Return:
A tensor mask with the indices to keep from the input set of boxes and scores.
Example:
>>> boxes = torch.tensor([
... [10., 10., 20., 20.],
... [15., 5., 15., 25.],
... [100., 100., 200., 200.],
... [100., 100., 200., 200.]])
>>> scores = torch.tensor([0.9, 0.8, 0.7, 0.9])
>>> nms(boxes, scores, iou_threshold=0.8)
tensor([0, 3, 1])
"""
if len(boxes.shape) != 2 and boxes.shape[-1] != 4:
raise ValueError(f"boxes expected as Nx4. Got: {boxes.shape}.")
if len(scores.shape) != 1:
raise ValueError(f"scores expected as N. Got: {scores.shape}.")
if boxes.shape[0] != scores.shape[0]:
raise ValueError(f"boxes and scores mus have same shape. Got: {boxes.shape, scores.shape}.")
x1, y1, x2, y2 = boxes.unbind(-1)
areas = (x2 - x1) * (y2 - y1)
_, order = scores.sort(descending=True)
keep = []
while order.shape[0] > 0:
i = order[0]
keep.append(i)
xx1 = torch.max(x1[i], x1[order[1:]])
yy1 = torch.max(y1[i], y1[order[1:]])
xx2 = torch.min(x2[i], x2[order[1:]])
yy2 = torch.min(y2[i], y2[order[1:]])
w = torch.clamp(xx2 - xx1, min=0.0)
h = torch.clamp(yy2 - yy1, min=0.0)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = torch.where(ovr <= iou_threshold)[0]
order = order[inds + 1]
if len(keep) > 0:
return torch.stack(keep)
return torch.tensor(keep)