# Source code for kornia.contrib.classification
"""Module containing utilities for classification."""
import torch
from torch import nn
class ClassificationHead(nn.Module):
    """Module to be used as a classification head.

    The input features are averaged over dimension ``-2``, layer-normalised over the
    last dimension and linearly projected to ``num_classes`` scores.

    Args:
        embed_size: size of the last (feature) dimension of the input tensor.
        num_classes: an integer representing the numbers of classes to classify.

    Shape:
        - Input: :math:`(..., N, E)` where :math:`E` equals ``embed_size``; the
          :math:`N` dimension is averaged out.
        - Output: :math:`(..., \\text{num\\_classes})`.

    Example:
        >>> feat = torch.rand(1, 256, 256)
        >>> head = ClassificationHead(256, 10)
        >>> head(feat).shape
        torch.Size([1, 10])
    """

    def __init__(self, embed_size: int = 768, num_classes: int = 10) -> None:
        super().__init__()
        # Both layers operate on the last dimension, which must have size embed_size.
        self.norm = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores from a set of feature vectors.

        Args:
            x: input features whose last dimension has size ``embed_size``.

        Returns:
            The unnormalised class scores (linear-layer output), with dimension
            ``-2`` of ``x`` removed and the last dimension replaced by ``num_classes``.
        """
        # Average-pool over dimension -2 (e.g. the token/sequence axis of a (B, N, E) input).
        out = x.mean(-2)
        return self.linear(self.norm(out))