r"""Implementation of "differentiable spatial to numerical" (soft-argmax) operations, as described in the paper
"Numerical Coordinate Regression with Convolutional Neural Networks" by Nibali et al."""
from typing import Tuple
import torch
import torch.nn.functional as F
from kornia.testing import check_is_tensor
from kornia.utils.grid import create_meshgrid
def _validate_batched_image_tensor_input(tensor):
    """Check that ``tensor`` is a batched image tensor laid out as BxCxHxW.

    The input is first handed to ``check_is_tensor``; afterwards its shape must
    have exactly four entries.

    Raises:
        ValueError: if the shape of ``tensor`` does not have four dimensions.
    """
    check_is_tensor(tensor)

    num_dims = len(tensor.shape)
    if num_dims != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {tensor.shape}")
def spatial_softmax2d(input: torch.Tensor, temperature: torch.Tensor = torch.tensor(1.0)) -> torch.Tensor:
    r"""Apply the Softmax function over features in each image channel.

    Note that this function behaves differently to :py:class:`torch.nn.Softmax2d`, which
    instead applies Softmax over features at each spatial location.

    Args:
        input: the input tensor with shape :math:`(B, N, H, W)`.
        temperature: factor to apply to input, adjusting the "smoothness" of the output distribution.
            It is moved to the device and dtype of ``input`` before use.

    Returns:
        a 2D probability distribution per image channel with shape :math:`(B, N, H, W)`.

    Raises:
        ValueError: if ``input`` does not have four dimensions.

    Examples:
        >>> heatmaps = torch.tensor([[[
        ...     [0., 0., 0.],
        ...     [0., 0., 0.],
        ...     [0., 1., 2.]]]])
        >>> spatial_softmax2d(heatmaps)
        tensor([[[[0.0585, 0.0585, 0.0585],
                  [0.0585, 0.0585, 0.0585],
                  [0.0585, 0.1589, 0.4319]]]])
    """
    _validate_batched_image_tensor_input(input)

    batch_size, channels, height, width = input.shape
    temperature = temperature.to(device=input.device, dtype=input.dtype)

    # Flatten the spatial dimensions so softmax normalises over all H*W locations of a channel.
    # ``reshape`` is used rather than ``view`` so that inputs whose H/W strides cannot be merged
    # (e.g. a width-cropped slice of a larger tensor) are accepted instead of raising.
    x: torch.Tensor = input.reshape(batch_size, channels, -1)
    x_soft: torch.Tensor = F.softmax(x * temperature, dim=-1)

    # Restore the original BxNxHxW layout.
    return x_soft.reshape(batch_size, channels, height, width)
def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
    r"""Compute the expectation of coordinate values using spatial probabilities.

    The input heatmap is assumed to represent a valid spatial probability distribution,
    which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.

    Args:
        input: the input tensor representing dense spatial probabilities with shape :math:`(B, N, H, W)`.
        normalized_coordinates: whether to return the coordinates normalized in the range
            of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.

    Returns:
        expected value of the 2D coordinates with shape :math:`(B, N, 2)`. Output order of the coordinates is (x, y).

    Raises:
        ValueError: if ``input`` does not have four dimensions.

    Examples:
        >>> heatmaps = torch.tensor([[[
        ...     [0., 0., 0.],
        ...     [0., 0., 0.],
        ...     [0., 1., 0.]]]])
        >>> spatial_expectation2d(heatmaps, False)
        tensor([[[1., 2.]]])
    """
    _validate_batched_image_tensor_input(input)

    batch_size, channels, height, width = input.shape

    # Create coordinates grid and flatten each coordinate plane to H*W entries.
    grid: torch.Tensor = create_meshgrid(height, width, normalized_coordinates, input.device)
    grid = grid.to(input.dtype)

    pos_x: torch.Tensor = grid[..., 0].reshape(-1)
    pos_y: torch.Tensor = grid[..., 1].reshape(-1)

    # ``reshape`` is used rather than ``view`` so that inputs whose H/W strides cannot be merged
    # (e.g. a width-cropped slice of a larger tensor) are accepted instead of raising.
    input_flat: torch.Tensor = input.reshape(batch_size, channels, -1)

    # Compute the expectation of the coordinates: sum of position * probability over all locations.
    expected_y: torch.Tensor = torch.sum(pos_y * input_flat, -1, keepdim=True)
    expected_x: torch.Tensor = torch.sum(pos_x * input_flat, -1, keepdim=True)

    # Concatenate in (x, y) order along the last dimension.
    output: torch.Tensor = torch.cat([expected_x, expected_y], -1)

    return output.reshape(batch_size, channels, 2)  # BxNx2
def _safe_zero_division(numerator: torch.Tensor, denominator: torch.Tensor, eps: float = 1e-32) -> torch.Tensor:
return numerator / torch.clamp(denominator, min=eps)
def render_gaussian2d(
    mean: torch.Tensor, std: torch.Tensor, size: Tuple[int, int], normalized_coordinates: bool = True
) -> torch.Tensor:
    r"""Render the PDF of a 2D Gaussian distribution.

    The rendered values are rescaled so that each output image sums to one.

    Args:
        mean: the mean location of the Gaussian to render, :math:`(\mu_x, \mu_y)`. Shape: :math:`(*, 2)`.
        std: the standard deviation of the Gaussian to render, :math:`(\sigma_x, \sigma_y)`.
            Shape :math:`(*, 2)`. Should be able to be broadcast with `mean`.
        size: the (height, width) of the output image.
        normalized_coordinates: whether ``mean`` and ``std`` are assumed to use coordinates normalized
            in the range of :math:`[-1, 1]`. Otherwise, coordinates are assumed to be in the range of the output shape.

    Returns:
        tensor including rendered points with shape :math:`(*, H, W)`.

    Raises:
        TypeError: if ``mean`` and ``std`` do not share the same dtype and device.
    """
    if not (std.dtype == mean.dtype and std.device == mean.device):
        raise TypeError("Expected inputs to have the same dtype and device")

    height, width = size

    # Create coordinates grid.
    grid: torch.Tensor = create_meshgrid(height, width, normalized_coordinates, mean.device)
    grid = grid.to(mean.dtype)
    pos_x: torch.Tensor = grid[..., 0].view(height, width)
    pos_y: torch.Tensor = grid[..., 1].view(height, width)

    # Gaussian PDF = exp(-(x - \mu)^2 / (2 \sigma^2))
    #              = exp(dists * ks),
    #                where dists = (x - \mu)^2 and ks = -1 / (2 \sigma^2)

    # dists <- (x - \mu)^2
    # The trailing ``None, None`` indices add H and W axes so the means broadcast over the grid.
    dist_x = (pos_x - mean[..., 0, None, None]) ** 2
    dist_y = (pos_y - mean[..., 1, None, None]) ** 2

    # ks <- -1 / (2 \sigma^2)
    # ``std`` is a standard deviation, so it must be squared to obtain the variance here.
    k_x = -0.5 * torch.reciprocal(std[..., 0, None, None] ** 2)
    k_y = -0.5 * torch.reciprocal(std[..., 1, None, None] ** 2)

    # Assemble the 2D Gaussian as the product of the two separable 1D factors.
    exps_x = torch.exp(dist_x * k_x)
    exps_y = torch.exp(dist_y * k_y)
    gauss = exps_x * exps_y

    # Rescale so that values sum to one.
    val_sum = gauss.sum(-2, keepdim=True).sum(-1, keepdim=True)
    gauss = _safe_zero_division(gauss, val_sum)

    return gauss