unet
A simple U-Net with timm backbone encoder.
- Based on an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
- Hacked together by Ross Wightman: https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31
- class fgvc.special.unet.Conv2dBnAct(in_channels, out_channels, kernel_size, padding=0, stride=1, act_layer=<class 'torch.nn.modules.activation.ReLU'>, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>)
- forward(x)
Run forward pass.
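The `kernel_size`, `padding` and `stride` arguments determine the spatial size of the block's output. Below is a minimal sketch that assumes `Conv2dBnAct` wraps a standard `torch.nn.Conv2d` (dilation 1) and that its normalization and activation layers leave the shape unchanged. `conv2d_out_size` is a hypothetical helper that applies the output-size formula from the PyTorch `Conv2d` documentation. It does not call the fgvc code.

```python
def conv2d_out_size(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a 2D convolution with dilation 1.

    Hypothetical helper: assumes Conv2dBnAct's norm and activation layers
    do not change the spatial shape.
    """
    if size + 2 * padding < kernel_size:
        raise ValueError("kernel is larger than the padded input")
    return (size + 2 * padding - kernel_size) // stride + 1


# A 3x3 conv with padding=1 keeps the size; the default padding=0 shrinks it by 2.
print(conv2d_out_size(64, kernel_size=3, padding=1))            # 64
print(conv2d_out_size(64, kernel_size=3))                       # 62
print(conv2d_out_size(64, kernel_size=3, padding=1, stride=2))  # 32
```

Because the default is `padding=0`, a size-preserving 3×3 block needs `padding=1` to be passed explicitly.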
- class fgvc.special.unet.DecoderBlock(in_channels, out_channels, scale_factor=2.0, act_layer=<class 'torch.nn.modules.activation.ReLU'>, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>)
- forward(x, skip: Tensor | None = None)
Run forward pass.
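The `forward(x, skip)` signature suggests the usual U-Net decoder step. First `x` is upsampled, then the encoder skip connection (if any) is concatenated, and finally the result is convolved. The shape-only sketch below assumes exactly that behavior:

- upsampling by `scale_factor`;
- channel-wise concatenation of `skip`, with `in_channels` counting those channels;
- convolutions that preserve spatial size.

`decoder_block_shape` is hypothetical and does not call the fgvc code.

```python
from typing import Optional, Tuple

# (batch, channels, height, width), as in PyTorch's NCHW layout
Shape = Tuple[int, int, int, int]


def decoder_block_shape(x: Shape, skip: Optional[Shape], in_channels: int,
                        out_channels: int, scale_factor: float = 2.0) -> Shape:
    """Predict a decoder block's output shape (hypothetical, shape-only sketch)."""
    b, c, h, w = x
    # Assumed upsampling step: interpolation floors size * scale_factor.
    h, w = int(h * scale_factor), int(w * scale_factor)
    if skip is not None:
        sb, sc, sh, sw = skip
        # Concatenation along channels needs matching batch and spatial dims.
        if (sb, sh, sw) != (b, h, w):
            raise ValueError(f"skip {skip} does not match upsampled x {(b, c, h, w)}")
        c += sc
    # in_channels is assumed to include the concatenated skip channels.
    if c != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {c}")
    # Assumed 3x3 convolutions with padding=1 keep the spatial size.
    return (b, out_channels, h, w)


print(decoder_block_shape((2, 512, 8, 8), (2, 256, 16, 16), in_channels=768, out_channels=256))
# (2, 256, 16, 16)
```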
- class fgvc.special.unet.Unet(encoder, num_classes, decoder_use_batchnorm=True, decoder_channels=(256, 128, 64, 32, 16), center=False, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, scale_factors=None)
Unet is a fully convolutional neural network for image semantic segmentation.
NOTE: This is based on an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
- Parameters:
encoder – Classification model (without the final dense layers) used as the feature extractor of the segmentation model.
num_classes – Number of output classes (output shape: (batch, classes, h, w)).
decoder_channels – Number of Conv2D filters in each decoder block, one entry per block.
decoder_use_batchnorm – If True, use a BatchNormalisation layer between the Conv2D and activation layers.
center – If True, add a Conv2dReLU block on the encoder head.
- forward(x: Tensor)
Run forward pass.
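The decoder upsamples step by step and concatenates skips at matching resolutions, so the input height and width typically have to be divisible by the encoder's total stride. Below is a minimal sketch under these assumptions:

- the default five `decoder_channels` entries and default `scale_factors`;
- a backbone whose deepest feature map has stride 32;
- 2× upsampling per decoder block, so the logits have the documented shape (batch, num_classes, h, w).

`unet_output_shape` is a hypothetical helper, not part of fgvc.

```python
from typing import Sequence, Tuple


def unet_output_shape(batch: int, height: int, width: int, num_classes: int,
                      decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
                      encoder_stride: int = 32) -> Tuple[int, int, int, int]:
    """Check the input size and predict the logits shape (hypothetical sketch).

    Assumes each decoder block upsamples 2x and the blocks together undo the
    encoder's total stride, so the logits come out at input resolution.
    """
    if 2 ** len(decoder_channels) != encoder_stride:
        raise ValueError("decoder upsampling does not undo the encoder stride")
    for name, size in (("height", height), ("width", width)):
        if size % encoder_stride:
            raise ValueError(f"{name}={size} is not divisible by {encoder_stride}")
    return (batch, num_classes, height, width)


print(unet_output_shape(4, 256, 384, num_classes=3))  # (4, 3, 256, 384)
```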
- class fgvc.special.unet.UnetDecoder(encoder_channels, decoder_channels=(256, 128, 64, 32, 16), final_channels=1, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, center=False, scale_factors=None)
- forward(x: List[Tensor])
Run forward pass.
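`UnetDecoder` takes the encoder's channel counts and the desired `decoder_channels`. Each decoder block has to accept its predecessor's output plus the matching skip connection. The sketch below shows that bookkeeping under two assumptions:

- `encoder_channels` is ordered from the deepest feature map to the shallowest.
- Decoder blocks without a matching encoder feature get no skip.

`decoder_in_channels` is a hypothetical helper. The example encoder channels are illustrative and match a ResNet-34-style backbone.

```python
from typing import List, Sequence


def decoder_in_channels(encoder_channels: Sequence[int],
                        decoder_channels: Sequence[int]) -> List[int]:
    """Input channels of each decoder block (hypothetical helper).

    encoder_channels is assumed ordered deepest first; element 0 feeds the
    (optional) center block and the rest are skip connections.
    """
    # What arrives from below: the encoder head, then each previous block's output.
    from_below = [encoder_channels[0]] + list(decoder_channels[:-1])
    # Skip channels, padded with 0 for blocks that have no encoder feature to concatenate.
    skips = list(encoder_channels[1:])
    skips += [0] * (len(decoder_channels) - len(skips))
    return [below + skip for below, skip in zip(from_below, skips)]


# ResNet-34-style feature channels at strides 32, 16, 8, 4, 2 (deepest first).
print(decoder_in_channels([512, 256, 128, 64, 64], (256, 128, 64, 32, 16)))
# [768, 384, 192, 128, 32]
```

The last block has no skip, so it only sees the 32 channels from the previous block before projecting to `final_channels`.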
- fgvc.special.unet.create_unet(arch_name: str, num_classes: int, pretrained: bool = True, **kwargs) → Module
Create a U-Net model with a backbone from the timm library.
- Parameters:
arch_name – Architecture name of the timm backbone.
num_classes – Number of classes the U-Net predicts.
pretrained – If True, use a pretrained checkpoint.
- Returns:
A PyTorch neural network instance.
- Return type:
Module