# Source code for kornia.filters.median
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .kernels import get_binary_kernel2d
def _compute_zero_padding(kernel_size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    r"""Utility function that computes the zero padding tuple for a 2D kernel.

    Each side is padded by ``(k - 1) // 2``. For an odd kernel dimension ``k``
    this keeps the output of a stride-1 convolution the same size as its input.

    Args:
        kernel_size: the kernel size as ``(K_h, K_w)``, or a single int used for
          both dimensions.

    Returns:
        the padding ``(pad_h, pad_w)``.

    Raises:
        ValueError: if ``kernel_size`` is a sequence that does not hold exactly
          two values.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    # Explicit unpacking rejects sequences of the wrong length instead of
    # silently ignoring extra entries or failing with an IndexError.
    ky, kx = kernel_size
    return (ky - 1) // 2, (kx - 1) // 2
def median_blur(input: torch.Tensor, kernel_size: Union[int, Tuple[int, int]]) -> torch.Tensor:
    r"""Blur an image using the median filter.

    .. image:: _static/img/median_blur.png

    Each output pixel is the median of the ``K_h x K_w`` window centred on the
    corresponding input pixel. The border is zero-padded so that the spatial
    size is preserved, which means windows near the edge include zeros.

    Args:
        input: the input image with shape :math:`(B,C,H,W)`.
        kernel_size: the blurring kernel size as ``(K_h, K_w)``, or a single int
          used for both dimensions. Both sizes must be positive and odd.

    Returns:
        the blurred input tensor with shape :math:`(B,C,H,W)`.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional, or if ``kernel_size`` is
          not a pair of positive odd sizes.

    .. note::
       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
       filtering_operators.html>`__.

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> output = median_blur(input, (3, 3))
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    kernel_size = tuple(kernel_size)
    # With an even size the symmetric (k - 1) // 2 padding shrinks the conv output
    # by one pixel. The .view() below would then either fail or silently scramble
    # the data, so reject such sizes up front.
    if len(kernel_size) != 2 or any(k <= 0 or k % 2 == 0 for k in kernel_size):
        raise ValueError(f"Invalid kernel size, we expect a pair of positive odd integers. Got: {kernel_size}")

    padding: Tuple[int, int] = _compute_zero_padding(kernel_size)

    # prepare kernel
    kernel: torch.Tensor = get_binary_kernel2d(kernel_size).to(input)
    b, c, h, w = input.shape

    # map the local window to single vector; channels are folded into the batch
    # so the single-input-channel kernel is applied to every channel independently
    features: torch.Tensor = F.conv2d(input.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1)
    features = features.view(b, c, -1, h, w)  # BxCx(K_h * K_w)xHxW

    # compute the median along the feature axis
    return features.median(dim=2).values
class MedianBlur(nn.Module):
    r"""Module form of :func:`median_blur`: blur an image using the median filter.

    The module holds no parameters or buffers. It only remembers the kernel size
    and forwards it to :func:`median_blur` on every call.

    Args:
        kernel_size: the blurring kernel size.

    Returns:
        the blurred input tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> blur = MedianBlur((3, 3))
        >>> output = blur(input)
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """

    def __init__(self, kernel_size: Tuple[int, int]) -> None:
        super().__init__()
        # Stored as given; validation happens inside median_blur at call time.
        self.kernel_size: Tuple[int, int] = kernel_size

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        window = self.kernel_size
        return median_blur(input, kernel_size=window)