ClassificationTrainer
This page documents the properties and methods of ClassificationTrainer:
from fgvc.core.training import ClassificationTrainer
trainer = ClassificationTrainer(model, trainloader, criterion, optimizer)
- class fgvc.core.training.ClassificationTrainer(model: Module, trainloader: DataLoader, criterion: Module, optimizer: Optimizer, *, validloader: DataLoader | None = None, scheduler: ReduceLROnPlateau | CosineLRScheduler | CosineAnnealingLR | None = None, accumulation_steps: int = 1, clip_grad: float | None = None, device: device | None = None, train_scores_fn: Callable | None = None, valid_scores_fn: Callable | None = None, wandb_train_prefix: str = 'Train. ', wandb_valid_prefix: str = 'Val. ', mixup: float | None = None, cutmix: float | None = None, mixup_prob: float | None = None, apply_ema: bool = False, ema_start_epoch: int = 0, ema_decay: float = 0.9999, **kwargs)
Class to perform training of a classification neural network and/or run inference.
- Parameters:
model – PyTorch neural network.
trainloader – PyTorch dataloader with training data.
criterion – Loss function.
optimizer – Optimizer algorithm.
validloader – PyTorch dataloader with validation data.
scheduler – Scheduler algorithm.
accumulation_steps – Number of iterations to accumulate gradients before performing optimizer step.
clip_grad – Max norm of the gradients for the gradient clipping.
device – Device to use (cpu, 0, 1, 2, …).
train_scores_fn – Function for evaluating scores on the training data.
valid_scores_fn – Function for evaluating scores on the validation data.
wandb_train_prefix – Prefix string to include in the name of training scores logged to W&B.
wandb_valid_prefix – Prefix string to include in the name of validation scores logged to W&B.
mixup – Mixup alpha value, mixup is active if > 0.
cutmix – Cutmix alpha value, cutmix is active if > 0.
mixup_prob – Probability of applying mixup or cutmix per batch.
apply_ema – Apply EMA model weight averaging if true.
ema_start_epoch – Epoch number at which to start model weight averaging.
ema_decay – Decay rate used for averaging the EMA model weights.
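A minimal construction sketch, assuming a toy model and synthetic dataloaders; everything except ClassificationTrainer and its documented arguments is illustrative:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from fgvc.core.training import ClassificationTrainer

# Toy model and synthetic data, for illustration only.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
train_ds = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
valid_ds = TensorDataset(torch.randn(32, 3, 32, 32), torch.randint(0, 10, (32,)))
trainloader = DataLoader(train_ds, batch_size=16, shuffle=True)
validloader = DataLoader(valid_ds, batch_size=16)

trainer = ClassificationTrainer(
    model,
    trainloader,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
    validloader=validloader,
    accumulation_steps=2,  # optimizer step every 2 batches
    clip_grad=1.0,         # clip gradient norm at 1.0
    mixup=0.2, cutmix=1.0, mixup_prob=0.5,
    apply_ema=True, ema_start_epoch=1,
)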
- apply_mixup(imgs: Tensor, targs: Tensor) Tuple[Tensor, Tensor]
Apply the mixup or cutmix method if the mixup or cutmix arguments were set when creating the trainer.
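A brief usage sketch, reusing the trainer and trainloader from the construction example above (mixup and cutmix were set there):
imgs, targs = next(iter(trainloader))
mixed_imgs, mixed_targs = trainer.apply_mixup(imgs, targs)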
- create_ema_model()
Initialize EMA averaged model.
- get_ema_model()
Get EMA averaged model.
- make_ema_update(epoch: int)
Update weights of the EMA averaged model.
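A hedged sketch of how the three EMA methods fit together; train() presumably drives this sequence internally when apply_ema=True, and it is shown here only to illustrate the order of calls:
trainer.create_ema_model()           # initialize the averaged copy
for epoch in range(1, 4):
    # ... train one epoch ...
    trainer.make_ema_update(epoch)   # averaging begins at ema_start_epoch
ema_model = trainer.get_ema_model()  # retrieve the averaged model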
- make_scheduler_step(epoch: int | None = None, *, valid_loss: float | None = None)
Make scheduler step after training one epoch.
The method uses different arguments depending on the scheduler type.
- Parameters:
epoch – Current epoch number. The method expects epoch numbering to start at 1 (instead of 0).
valid_loss – Average validation loss to use for ReduceLROnPlateau scheduler.
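A hedged sketch of the per-epoch call; avg_valid_loss is a hypothetical variable holding the average validation loss:
# With ReduceLROnPlateau, pass the average validation loss:
trainer.make_scheduler_step(epoch, valid_loss=avg_valid_loss)
# With epoch-based schedulers, the 1-based epoch number is used instead:
trainer.make_scheduler_step(epoch)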
- make_timm_scheduler_update(num_updates: int)
Make a scheduler step update after training one iteration.
This is specific to timm schedulers.
- Parameters:
num_updates – Iteration number.
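A sketch of per-iteration updates within an epoch, assuming num_updates counts iterations from the start of training (mirroring the timm convention):
num_updates = (epoch - 1) * len(trainloader)  # iterations completed so far (assumption)
for batch in trainloader:
    # ... forward/backward pass and optimizer step ...
    num_updates += 1
    trainer.make_timm_scheduler_update(num_updates)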
- predict(dataloader: DataLoader, return_preds: bool = True, *, model: Module | None = None) PredictOutput
Run inference.
- Parameters:
dataloader – PyTorch dataloader with validation/test data.
return_preds – If True, the method returns predictions and ground-truth targets.
model – Alternative PyTorch model to use for prediction, e.g., the EMA model.
- Returns:
PredictOutput tuple with predictions, ground-truth targets,
average loss, and average scores.
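A usage sketch; the unpacking assumes PredictOutput is a named 4-tuple in the order listed above:
preds, targs, avg_loss, avg_scores = trainer.predict(validloader)
# Evaluate the EMA weights instead of the live model:
ema_output = trainer.predict(validloader, model=trainer.get_ema_model())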
- predict_batch(batch: tuple, *, model: Module | None = None) BatchOutput
Run a prediction iteration on one batch.
- Parameters:
batch – Tuple of arbitrary size containing image and target PyTorch tensors, optionally followed by additional items depending on the dataloader.
model – Alternative PyTorch model to use for prediction, e.g., the EMA model.
- Return type:
BatchOutput tuple with predictions, ground-truth targets, and average loss.
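A brief sketch; the unpacking assumes BatchOutput is a 3-tuple in the order listed above:
batch = next(iter(validloader))
preds, targs, loss = trainer.predict_batch(batch)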
- train(num_epochs: int = 1, seed: int = 777, path: str | None = None, resume: bool = False)
Train neural network.
- Parameters:
num_epochs – Number of epochs to train.
seed – Random seed to set.
path – Experiment path for saving training outputs like checkpoints or logs.
resume – If True, resumes the run from a checkpoint, restoring optimizer and scheduler state.
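A hedged call example; the experiment path is hypothetical:
trainer.train(num_epochs=10, seed=777, path="runs/exp1")
# Continue the same run from its checkpoint, restoring optimizer/scheduler state:
trainer.train(num_epochs=10, path="runs/exp1", resume=True)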
- train_batch(batch: tuple) BatchOutput
Run a training iteration on one batch.
- Parameters:
batch – Tuple of arbitrary size containing image and target PyTorch tensors, optionally followed by additional items depending on the dataloader.
- Return type:
BatchOutput tuple with predictions, ground-truth targets, and average loss.
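A brief sketch, analogous to predict_batch; mixup/cutmix and gradient accumulation are presumably applied inside the call:
batch = next(iter(trainloader))
preds, targs, loss = trainer.train_batch(batch)  # BatchOutput unpacking (assumption)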
- train_epoch(epoch: int, dataloader: DataLoader) TrainEpochOutput
Train one epoch.
- Parameters:
epoch – Epoch number.
dataloader – PyTorch dataloader with training data.
- Return type:
TrainEpochOutput tuple with average loss and average scores.
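A minimal manual-loop sketch built on train_epoch; the unpacking order of TrainEpochOutput is an assumption based on the description above, and num_epochs is a hypothetical variable:
for epoch in range(1, num_epochs + 1):
    avg_loss, avg_scores = trainer.train_epoch(epoch, trainloader)
    trainer.make_scheduler_step(epoch)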