Source code for kornia.contrib.edge_detection

from typing import List

from kornia.core import Module, Tensor
from kornia.filters.dexined import DexiNed
from kornia.testing import KORNIA_CHECK_SHAPE


class EdgeDetector(Module):
    r"""Detect edges in a given image using a CNN.

    By default, it uses the method described in :cite:`xsoria2020dexined`.

    Args:
        pretrained: forwarded as ``DexiNed(pretrained=...)``. The default ``True``
            keeps the previous behaviour of requesting the pretrained weights.
            Pass ``False`` to build the bare network, e.g. when custom weights
            will be loaded afterwards with :meth:`load`.

    Return:
        A tensor of shape :math:`(B,1,H,W)`.

    Example:
        >>> img = torch.rand(1, 3, 320, 320)
        >>> detect = EdgeDetector()
        >>> out = detect(img)
        >>> out.shape
        torch.Size([1, 1, 320, 320])
    """

    def __init__(self, *, pretrained: bool = True) -> None:
        super().__init__()
        # Keyword-only so the boolean flag is always explicit at call sites.
        self.model = DexiNed(pretrained=pretrained)

    def load(self, path_file: str) -> None:
        """Load weights into the underlying model from ``path_file``.

        Delegates to ``self.model.load_from_file``.

        Args:
            path_file: path passed unchanged to ``DexiNed.load_from_file``.
        """
        self.model.load_from_file(path_file)

    def preprocess(self, image: Tensor) -> Tensor:
        """Hook applied to the input before the model runs.

        Currently the identity: the image is returned unchanged.
        """
        return image

    def postprocess(self, data: List[Tensor]) -> Tensor:
        """Select the final prediction from the model outputs.

        Args:
            data: list of outputs from intermediate layers of the model.

        Returns:
            The last element of ``data``, which is the only output needed for
            inference (expected shape :math:`(B,1,H,W)`).
        """
        return data[-1]  # Bx1xHxW

    def forward(self, image: Tensor) -> Tensor:
        """Detect edges in ``image``.

        Args:
            image: batch of 3-channel images of shape :math:`(B,3,H,W)`.

        Returns:
            The edge map produced by :meth:`postprocess`, of shape :math:`(B,1,H,W)`.
        """
        # Reject inputs that are not 4-D with exactly 3 channels.
        KORNIA_CHECK_SHAPE(image, ["B", "3", "H", "W"])
        img = self.preprocess(image)
        out = self.model(img)
        return self.postprocess(out)