Models
- fgvc.core.models.get_model(architecture_name: str, target_size: int | None = None, pretrained: bool = False, *, checkpoint_path: str | BytesIO | None = None, strict: bool = True) → Module
Get a timm model.
- Parameters:
architecture_name – Name of the network architecture from the timm library.
target_size – Output feature size of the new prediction head.
pretrained – If true, load pretrained weights from the timm library.
checkpoint_path – Path (or IO buffer) with checkpoint weights to load after the model is initialized.
strict – Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model. Only used when checkpoint_path is specified.
- Returns:
PyTorch model from timm library.
- Return type:
Module
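The following is a minimal usage sketch of get_model, not part of the library. The helper name load_classifier and its arguments are hypothetical; only the get_model call and its parameters come from the signature above. The import sits inside the function so the helper can be defined even where fgvc is not installed.

```python
from io import BytesIO
from typing import Optional, Union


def load_classifier(
    architecture_name: str,
    num_classes: int,
    checkpoint: Optional[Union[str, BytesIO]] = None,
):
    """Hypothetical helper: build a timm model through fgvc's get_model."""
    # Imported lazily so this helper can be defined without fgvc installed.
    from fgvc.core.models import get_model

    if checkpoint is None:
        # No checkpoint given: start from the pretrained timm weights.
        return get_model(architecture_name, target_size=num_classes, pretrained=True)

    # checkpoint_path and strict are keyword-only in the documented signature.
    # strict=True asks for the checkpoint keys to match the model keys.
    return get_model(
        architecture_name,
        target_size=num_classes,
        pretrained=False,
        checkpoint_path=checkpoint,
        strict=True,
    )
```

A call such as load_classifier("resnet18", 10) would request a ResNet-18 with a 10-output prediction head; passing a path or a BytesIO object as checkpoint loads those weights instead of the pretrained ones.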
- fgvc.core.models.get_model_target_size(model: Module) → int | None
Get the target size (number of output classes) of a timm model.
- Parameters:
model – PyTorch model from timm library.
- Returns:
Output feature size of the prediction head.
- Return type:
int | None
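Because the return type is int | None, callers should handle the None case explicitly. The sketch below shows one way to do that; describe_head is a hypothetical helper, and the interpretation of None as "size could not be determined" is an assumption, since the reference above does not say when None is returned.

```python
def describe_head(model) -> str:
    """Hypothetical helper: summarize the prediction head of a timm model."""
    # Imported lazily so this helper can be defined without fgvc installed.
    from fgvc.core.models import get_model_target_size

    target_size = get_model_target_size(model)
    if target_size is None:
        # Assumption: None means no output size could be determined.
        return "prediction head size unknown"
    return f"prediction head with {target_size} outputs"
```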
- fgvc.core.models.set_prediction_head(model: Module, target_size: int, *, in_features: int | None = None)
Replace the prediction head of a timm model.
- Parameters:
model – PyTorch model from timm library.
target_size – Output feature size of the new prediction head.
in_features – Number of input features for the prediction head. The parameter is needed in special cases, e.g., when the current prediction head is nn.Identity.
- Returns:
The input timm model with the new prediction head.
- Return type:
Module
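A common pattern is to replace the head only when its size differs from the number of classes in the target dataset. The sketch below combines get_model_target_size and set_prediction_head for that purpose; resize_head_if_needed is a hypothetical helper, not part of the library, and it relies only on the signatures documented above.

```python
from typing import Optional


def resize_head_if_needed(model, num_classes: int, in_features: Optional[int] = None):
    """Hypothetical helper: replace the head only when its size differs."""
    # Imported lazily so this helper can be defined without fgvc installed.
    from fgvc.core.models import get_model_target_size, set_prediction_head

    if get_model_target_size(model) == num_classes:
        # The head already has the requested output size; leave it untouched.
        return model

    # in_features is keyword-only; per the reference it is needed in special
    # cases, e.g. when the current prediction head is nn.Identity.
    return set_prediction_head(model, num_classes, in_features=in_features)
```

When get_model_target_size returns None, the comparison fails and the head is replaced, so in_features may need to be supplied in that case.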