Models

fgvc.core.models.get_model(architecture_name: str, target_size: int | None = None, pretrained: bool = False, *, checkpoint_path: str | BytesIO | None = None, strict: bool = True) → Module

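Create a model from the timm library, optionally with a new prediction head and with weights loaded from a checkpoint.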

Parameters:
  • architecture_name – Name of the network architecture from the timm library.

  • target_size – Output feature size of the new prediction head.

  • pretrained – If True, load pretrained weights from the timm library.

  • checkpoint_path – Path (or IO buffer) of a checkpoint with weights to load after the model is initialized.

  • strict – Whether to strictly enforce that the keys in the checkpoint's state_dict match the keys of the model. Used only when checkpoint_path is specified.

Returns:

PyTorch model from the timm library.

Return type:

Module
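The checkpoint_path and strict arguments presumably follow PyTorch's standard state_dict loading semantics. The sketch below assumes the checkpoint is a plain state_dict saved with torch.save; it does not call fgvc, and the small nn.Sequential networks built by the hypothetical make_net helper stand in for a timm model. It shows an in-memory BytesIO buffer used as the checkpoint source and how strict=True and strict=False differ when the model has layers the checkpoint lacks.

```python
import io

import torch
from torch import nn


def make_net(extra_layer: bool = False) -> nn.Sequential:
    """Hypothetical stand-in for a timm network; optionally adds a layer the checkpoint lacks."""
    layers = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3)]
    if extra_layer:
        layers.append(nn.Linear(3, 2))
    return nn.Sequential(*layers)


# Save a state_dict into an in-memory buffer (the BytesIO option of checkpoint_path).
source = make_net()
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)

# strict=True: keys match exactly, so an identical architecture loads cleanly.
target = make_net()
strict_result = target.load_state_dict(state_dict, strict=True)

# strict=False: shared keys are loaded and unmatched keys are only reported.
extended = make_net(extra_layer=True)
loose_result = extended.load_state_dict(state_dict, strict=False)
print(loose_result.missing_keys)  # ['3.weight', '3.bias']

# strict=True on the same mismatch raises instead of silently skipping keys.
try:
    make_net(extra_layer=True).load_state_dict(state_dict, strict=True)
    strict_mismatch_raised = False
except RuntimeError:
    strict_mismatch_raised = True
```

Note that strict=False only tolerates missing or unexpected keys; a key present in both but with a different tensor shape still raises an error.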

fgvc.core.models.get_model_target_size(model: Module) → int | None

Get target size (number of output classes) of a timm model.

Parameters:

model – PyTorch model from the timm library.

Returns:

Output feature size of the model's prediction head.

Return type:

int | None
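The int | None return type suggests that a head without a fixed output size yields None. The sketch below illustrates that idea; it is not fgvc's implementation. ToyNet and head_output_size are hypothetical names, and the sketch assumes the head is stored in an attribute named fc (as in timm's ResNets) and that only an nn.Linear head has a well-defined output size.

```python
from typing import Optional

import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a timm model whose head is stored as `fc`."""

    def __init__(self, num_features: int = 16, num_classes: int = 1000) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, num_features), nn.ReLU())
        self.fc: nn.Module = nn.Linear(num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.backbone(x))


def head_output_size(model: nn.Module, head_name: str = "fc") -> Optional[int]:
    """Return the head's out_features, or None if the head is not an nn.Linear layer."""
    head = getattr(model, head_name)
    if isinstance(head, nn.Linear):
        return head.out_features
    return None  # e.g. nn.Identity has no fixed output size


model = ToyNet(num_classes=10)
print(head_output_size(model))  # 10
model.fc = nn.Identity()
print(head_output_size(model))  # None
```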

fgvc.core.models.set_prediction_head(model: Module, target_size: int, *, in_features: int | None = None)

Replace the prediction head of a timm model.

Parameters:
  • model – PyTorch model from the timm library.

  • target_size – Output feature size of the new prediction head.

  • in_features – Number of input features for the new prediction head. Needed only in special cases, e.g., when the current prediction head is nn.Identity and therefore has no input feature size that can be reused.

Returns:

The input timm model with the new prediction head.

Return type:

Module
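The in_features parameter matters because nn.Identity stores no feature dimension that a new head could copy. The sketch below shows that logic on a hypothetical ToyNet standing in for a timm model; replace_head is a hypothetical, simplified helper that assumes the head lives in an attribute named fc, and it is not fgvc's implementation.

```python
from typing import Optional

import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a timm model whose head is stored as `fc`."""

    def __init__(self, num_features: int = 16, num_classes: int = 1000) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, num_features), nn.ReLU())
        self.fc: nn.Module = nn.Linear(num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.backbone(x))


def replace_head(model: nn.Module, target_size: int, *, in_features: Optional[int] = None) -> nn.Module:
    """Swap model.fc for a new nn.Linear layer with target_size outputs."""
    if in_features is None:
        if not isinstance(model.fc, nn.Linear):
            # nn.Identity carries no in_features, so the caller must supply it.
            raise ValueError("in_features is required when the current head is not nn.Linear")
        in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, target_size)
    return model


model = replace_head(ToyNet(), target_size=5)  # in_features inferred from the old head (16)
print(model(torch.zeros(2, 8)).shape)  # torch.Size([2, 5])

headless = ToyNet()
headless.fc = nn.Identity()  # e.g. a model used as a feature extractor
headless = replace_head(headless, target_size=3, in_features=16)
print(headless(torch.zeros(2, 8)).shape)  # torch.Size([2, 3])
```

Raising an error instead of guessing the feature size surfaces the problem when the head is replaced, rather than as a shape mismatch at the first forward pass.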