unet

A simple U-Net with a timm backbone as the encoder.

Based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch

Hacked together by Ross Wightman: https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31

class fgvc.special.unet.Conv2dBnAct(in_channels, out_channels, kernel_size, padding=0, stride=1, act_layer=<class 'torch.nn.modules.activation.ReLU'>, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>)
forward(x)

Run forward pass.
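
The kernel_size, padding and stride arguments together determine the spatial size of the block's output. The sketch below assumes the block wraps a standard torch.nn.Conv2d with dilation 1 (the implementation is not shown here, so this is an assumption), and applies the usual convolution size formula. The helper name conv_out_size is hypothetical and not part of fgvc.

```python
def conv_out_size(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Output length along one spatial dimension for a dilation-1 convolution."""
    return (size + 2 * padding - kernel_size) // stride + 1


# A 3x3 kernel with the default padding=0 shrinks the feature map by 2 pixels.
print(conv_out_size(64, kernel_size=3))                       # 62
# padding=1 keeps the spatial size unchanged for a 3x3 kernel at stride 1.
print(conv_out_size(64, kernel_size=3, padding=1))            # 64
# stride=2 halves the spatial size.
print(conv_out_size(64, kernel_size=3, padding=1, stride=2))  # 32
```

Because padding defaults to 0, a caller who wants "same"-sized outputs with an odd kernel would presumably pass padding=kernel_size // 2 explicitly.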

class fgvc.special.unet.DecoderBlock(in_channels, out_channels, scale_factor=2.0, act_layer=<class 'torch.nn.modules.activation.ReLU'>, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>)
forward(x, skip: Tensor | None = None)

Run forward pass.
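
The forward signature accepts an optional skip tensor. The sketch below shows the shape bookkeeping this implies. It assumes the block first upsamples x by scale_factor and then concatenates skip along the channel axis, as is common in U-Net decoders. The source does not confirm this behaviour, and the helper decoder_block_merged_shape is hypothetical.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def decoder_block_merged_shape(
    x: Shape, skip: Optional[Shape] = None, scale_factor: float = 2.0
) -> Shape:
    """Shape of the tensor after upsampling x and concatenating skip (if any)."""
    channels, h, w = x
    h, w = int(h * scale_factor), int(w * scale_factor)
    if skip is not None:
        skip_channels, skip_h, skip_w = skip
        if (skip_h, skip_w) != (h, w):
            raise ValueError("skip must match the upsampled spatial size")
        channels += skip_channels  # concatenation along the channel axis
    return channels, h, w


# With a skip connection: channels add up, spatial size doubles.
print(decoder_block_merged_shape((512, 7, 7), skip=(256, 14, 14)))  # (768, 14, 14)
# Without a skip connection: only the upsampling applies.
print(decoder_block_merged_shape((16, 112, 112)))                   # (16, 224, 224)
```

Under this assumption, the in_channels passed to DecoderBlock would need to cover the combined channel count, not just the channels of x.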

class fgvc.special.unet.Unet(encoder, num_classes, decoder_use_batchnorm=True, decoder_channels=(256, 128, 64, 32, 16), center=False, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, scale_factors=None)

Unet is a fully convolutional neural network for image semantic segmentation.

NOTE: This is based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch

Parameters:
  • encoder – Classification model (without last dense layers) used as feature extractor to build segmentation model.

  • num_classes – Number of output classes; the output has shape (batch, classes, h, w).

  • decoder_channels – Number of Conv2D filters in each decoder block, given as a sequence with one entry per block.

  • decoder_use_batchnorm – If True, use a batch normalisation layer between the Conv2D and activation layers.

  • center – If True, add a Conv2dReLU block on the encoder head.

forward(x: Tensor)

Run forward pass.
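
To make the relationship between decoder_channels and the (batch, classes, h, w) output concrete, the sketch below traces per-sample shapes through the decoder. It assumes an encoder whose deepest feature map has stride 32 (typical of ResNet-style backbones) and a decoder in which each block upsamples by 2. Neither is confirmed by this page, and trace_decoder_shapes is a hypothetical helper, not part of fgvc.

```python
from typing import List, Tuple


def trace_decoder_shapes(
    input_hw: Tuple[int, int],
    num_classes: int,
    decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
    encoder_stride: int = 32,  # assumed stride of the deepest encoder feature map
    scale_factor: int = 2,     # assumed upsampling factor per decoder block
) -> List[Tuple[int, int, int]]:
    """Return (channels, h, w) after each decoder block, then the final output."""
    h, w = input_hw[0] // encoder_stride, input_hw[1] // encoder_stride
    shapes = []
    for channels in decoder_channels:
        h, w = h * scale_factor, w * scale_factor
        shapes.append((channels, h, w))
    # The final layer maps the last decoder output to num_classes channels.
    shapes.append((num_classes, h, w))
    return shapes


for shape in trace_decoder_shapes((224, 224), num_classes=3):
    print(shape)
# (256, 14, 14)
# (128, 28, 28)
# (64, 56, 56)
# (32, 112, 112)
# (16, 224, 224)
# (3, 224, 224)
```

With five decoder blocks and a stride-32 encoder, the five 2x upsampling steps return the feature map to the input resolution, which is why the default decoder_channels has five entries.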

class fgvc.special.unet.UnetDecoder(encoder_channels, decoder_channels=(256, 128, 64, 32, 16), final_channels=1, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, center=False, scale_factors=None)
forward(x: List[Tensor])

Run forward pass.

fgvc.special.unet.create_unet(arch_name: str, num_classes: int, pretrained: bool = True, **kwargs) → Module

Create U-Net model with backbone from timm library.

Parameters:
  • arch_name – Architecture name of a backbone.

  • num_classes – Number of classes the U-Net predicts.

  • pretrained – If True, use pretrained checkpoint.

Returns:

The U-Net model as a PyTorch Module instance.

Return type:

Module