Source code for kornia.filters.dexined

# adapted from: https://github.com/xavysp/DexiNed/blob/d944b70eb6eaf40e22f8467c1e12919aa600d8e4/model.py
from collections import OrderedDict
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.core import Module, Tensor, concatenate
from kornia.utils import map_location_to_cpu

url: str = "http://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth"


def weight_init(m):
    # Xavier-normal initialization, applied when no pretrained weights are loaded.
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        # convolutions with a single input channel get a plain normal init
        if m.weight.shape[1] == 1:
            torch.nn.init.normal_(m.weight, mean=0.0)

        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    # for fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)

        if m.weight.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

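# A minimal usage sketch (the helper below is illustrative and not part of the
# adapted DexiNed code): `weight_init` is meant to be passed to
# `nn.Module.apply`, which invokes it recursively on every submodule.
def _demo_weight_init() -> nn.Module:
    net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ConvTranspose2d(16, 1, 2, stride=2))
    net.apply(weight_init)  # Xavier-normal weights, zeroed biases
    return net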

class CoFusion(Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, out_ch, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 64)
        self.norm_layer2 = nn.GroupNorm(4, 64)

    def forward(self, x: Tensor) -> Tensor:
        # per-pixel softmax attention over the stacked edge maps
        attn = self.relu(self.norm_layer1(self.conv1(x)))
        attn = self.relu(self.norm_layer2(self.conv2(attn)))
        attn = F.softmax(self.conv3(attn), dim=1)

        # attention-weighted sum over the map dimension -> single-channel output
        return ((x * attn).sum(1)).unsqueeze(1)

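# A minimal usage sketch (illustrative, not part of the adapted code): CoFusion
# predicts per-pixel softmax weights over the stacked edge maps and returns
# their weighted sum as a single-channel map.
def _demo_cofusion() -> Tensor:
    fusion = CoFusion(6, 6)
    edge_maps = torch.rand(1, 6, 320, 320)  # six stacked side-output maps
    return fusion(edge_maps)  # -> shape (1, 1, 320, 320)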

class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        # conv1 over-pads (padding=2 with a 3x3 kernel) while conv2 uses no
        # padding, so the pair of convolutions preserves the spatial size
        super().__init__(
            OrderedDict(
                [
                    ('relu1', nn.ReLU(inplace=True)),
                    ('conv1', nn.Conv2d(input_features, out_features, kernel_size=3, stride=1, padding=2, bias=True)),
                    ('norm1', nn.BatchNorm2d(out_features)),
                    ('relu2', nn.ReLU(inplace=True)),
                    ('conv2', nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True)),
                    ('norm2', nn.BatchNorm2d(out_features)),
                ]
            )
        )

    def forward(self, x: List[Tensor]) -> List[Tensor]:
        # x carries two streams: the features to refine and a fixed skip tensor
        x1, x2 = x[0], x[1]
        x3: Tensor = x1
        for mod in self:
            x3 = mod(x3)
        # average the refined features with the skip; pass the skip through unchanged
        return [0.5 * (x3 + x2), x2]


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, input_features, out_features):
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_features, out_features)
            self.add_module('denselayer%d' % (i + 1), layer)
            input_features = out_features

    def forward(self, x: List[Tensor]) -> List[Tensor]:
        x_out = x
        for mod in self:
            x_out = mod(x_out)
        return x_out

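# A minimal usage sketch (illustrative shapes, matching `dblock_3` of DexiNed
# below): the block carries two streams, the features being refined and a
# fixed skip tensor that is averaged back in after every layer.
def _demo_dense_block() -> Tensor:
    block = _DenseBlock(2, 128, 256)
    main = torch.rand(1, 128, 50, 50)  # stream refined by the conv stacks
    skip = torch.rand(1, 256, 50, 50)  # pre-dense skip, reused by every layer
    out, _ = block([main, skip])       # -> shape (1, 256, 50, 50)
    return out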

class UpConvBlock(Module):
    def __init__(self, in_features, up_scale):
        super().__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)
        if layers is None:
            raise ValueError("layers cannot be None")
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features: int, up_scale: int) -> List[Module]:
        layers: List[Module] = []
        all_pads = [0, 0, 1, 3, 7]
        for i in range(up_scale):
            kernel_size = 2**up_scale
            pad = all_pads[up_scale]  # chosen so each stride-2 deconv exactly doubles H and W
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(out_features, out_features, kernel_size, stride=2, padding=pad))
            in_features = out_features
        return layers

    def compute_out_features(self, idx: int, up_scale: int):
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x: Tensor, out_shape: List[int]) -> Tensor:
        out = self.features(x)
        # compare as lists: `out.shape[-2:]` is a torch.Size (a tuple subclass),
        # which never compares equal to a List[int] directly
        if list(out.shape[-2:]) != list(out_shape):
            out = F.interpolate(out, out_shape, mode='bilinear')
        return out

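# A minimal usage sketch (illustrative): `UpConvBlock(256, 2)` upsamples by
# 2**2 = 4 via two stride-2 transposed convolutions and reduces the output to
# a single channel; `out_shape` guards against off-by-one spatial sizes.
def _demo_up_conv_block() -> Tensor:
    up = UpConvBlock(256, 2)
    feats = torch.rand(1, 256, 80, 80)
    return up(feats, [320, 320])  # -> shape (1, 1, 320, 320)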

class SingleConvBlock(Module):
    def __init__(self, in_features, out_features, stride, use_bn=True):
        super().__init__()
        self.use_bn = use_bn
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        return x

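# A minimal usage sketch (illustrative): a 1x1 convolution used by DexiNed for
# the side and pre-dense skip connections, optionally followed by BatchNorm.
def _demo_single_conv_block() -> Tensor:
    block = SingleConvBlock(64, 128, stride=2)  # channel projection + downsample
    feats = torch.rand(1, 64, 160, 160)
    return block(feats)  # -> shape (1, 128, 80, 80)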

class DoubleConvBlock(nn.Sequential):
    def __init__(self, in_features, mid_features, out_features=None, stride=1, use_act=True):
        super().__init__()
        if out_features is None:
            out_features = mid_features
        self.add_module("conv1", nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride))
        self.add_module("bn1", nn.BatchNorm2d(mid_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(mid_features, out_features, 3, padding=1))
        self.add_module("bn2", nn.BatchNorm2d(out_features))
        if use_act:
            self.add_module("relu2", nn.ReLU(inplace=True))

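# A minimal usage sketch (illustrative, mirroring `block_1` of DexiNed below):
# two 3x3 conv+BN(+ReLU) stages, with the first stride-2 conv halving H and W.
def _demo_double_conv_block() -> Tensor:
    block = DoubleConvBlock(3, 32, 64, stride=2)
    img = torch.rand(1, 3, 320, 320)
    return block(img)  # -> shape (1, 64, 160, 160)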

class DexiNed(Module):
    r"""Definition of the DexiNed (Dense Extreme Inception Network for Edge Detection)
    network from :cite:`xsoria2020dexined`.

    Return:
        A list of tensors with the intermediate features, of which the last element
        is the edge map with shape :math:`(B,1,H,W)`.

    Example:
        >>> img = torch.rand(1, 3, 320, 320)
        >>> net = DexiNed(pretrained=False)
        >>> out = net(img)
        >>> out[-1].shape
        torch.Size([1, 1, 320, 320])
    """

    def __init__(self, pretrained: bool):
        super().__init__()
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128, use_act=False)
        self.dblock_3 = _DenseBlock(2, 128, 256)  # [128,256,100,100]
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # left skip connections, figure in Journal paper
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)  # unused in forward

        # right skip connections, figure in Journal paper
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)

        # USNet
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        self.block_cat = SingleConvBlock(6, 1, stride=1, use_bn=False)  # hed fusion method
        # self.block_cat = CoFusion(6, 6)  # cats fusion method

        if pretrained:
            self.load_from_file(url)
        else:
            self.apply(weight_init)

    def load_from_file(self, path_file: str) -> None:
        # use torch.hub to load the pretrained model
        pretrained_dict = torch.hub.load_state_dict_from_url(path_file, map_location=map_location_to_cpu)
        self.load_state_dict(pretrained_dict, strict=True)
        self.eval()

    def forward(self, x: Tensor) -> List[Tensor]:
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)

        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)  # [128,256,50,50]
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)

        # Block 4
        block_2_resize_half = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(block_3_down + block_2_resize_half)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)

        # Block 5
        block_5_pre_dense = self.pre_dense_5(block_4_down)  # block_5_pre_dense_512 + block_4_down
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side

        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])

        # upsampling blocks
        out_shape: List[int] = list(x.shape[-2:])
        out_1 = self.up_block_1(block_1, out_shape)
        out_2 = self.up_block_2(block_2, out_shape)
        out_3 = self.up_block_3(block_3, out_shape)
        out_4 = self.up_block_4(block_4, out_shape)
        out_5 = self.up_block_5(block_5, out_shape)
        out_6 = self.up_block_6(block_6, out_shape)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]

        # concatenate multiscale outputs
        block_cat = concatenate(results, 1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW

        results.append(block_cat)
        return results
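
# A minimal end-to-end sketch (illustrative, not part of the adapted code):
# the network returns seven maps, six side outputs plus their fused
# combination; a sigmoid turns the raw fused logits into edge probabilities.
def _demo_dexined() -> Tensor:
    net = DexiNed(pretrained=False).eval()
    img = torch.rand(1, 3, 320, 320)
    with torch.no_grad():
        outs = net(img)  # list of 7 tensors, each (1, 1, 320, 320)
    return torch.sigmoid(outs[-1])  # fused edge-probability map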