metallic.trainer

class metallic.trainer.Trainer(metalearner: metallic.metalearners.base.MetaLearner, train_loader: metallic.data.dataloader.MetaDataLoader, val_loader: Optional[metallic.data.dataloader.MetaDataLoader] = None, n_epoches: int = 100, n_iters_per_epoch: int = 500, n_iters_test: int = 600, logger: Optional[metallic.utils.logger.Logger] = None)

Bases: object

A wrapper around the training procedure for meta-learning algorithms.

Parameters
  • metalearner (MetaLearner) – An instance of the MetaLearner class

  • train_loader (MetaDataLoader) – Training data loader, an instance of the MetaDataLoader class

  • val_loader (MetaDataLoader, optional) – Validation data loader, an instance of the MetaDataLoader class

  • n_epoches (int, optional, default=100) – Number of epochs

  • n_iters_per_epoch (int, optional, default=500) – Number of iterations per epoch

  • n_iters_test (int, optional, default=600) – Number of iterations during the meta-test stage

  • logger (Logger, optional) – An instance of the Logger class

lr_schedule()

Schedule the learning rate.

run_epoch(epoch: int, train: bool = True)

Run one epoch of training (train=True) or evaluation (train=False).

run_train()

Run the training procedure.
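
How run_train, run_epoch, and the iteration-count parameters might fit together can be sketched as below. This is not metallic's implementation: SketchTrainer, step_fn, and has_val are hypothetical stand-ins, and two behaviours are assumptions inferred from the signatures above, namely that each epoch runs a training pass followed by a validation pass only when a validation loader is supplied, and that evaluation runs for n_iters_test iterations.

```python
from typing import Callable, List, Tuple


class SketchTrainer:
    """Hypothetical stand-in that mirrors the Trainer method names only."""

    def __init__(
        self,
        step_fn: Callable[[bool], float],  # hypothetical: runs one iteration, returns a loss
        has_val: bool = False,             # stands in for "val_loader is not None"
        n_epoches: int = 100,
        n_iters_per_epoch: int = 500,
        n_iters_test: int = 600,
    ) -> None:
        self.step_fn = step_fn
        self.has_val = has_val
        self.n_epoches = n_epoches
        self.n_iters_per_epoch = n_iters_per_epoch
        self.n_iters_test = n_iters_test
        self.history: List[Tuple[int, str, float]] = []

    def run_epoch(self, epoch: int, train: bool = True) -> float:
        # Assumption: training uses n_iters_per_epoch, evaluation uses n_iters_test.
        n_iters = self.n_iters_per_epoch if train else self.n_iters_test
        total = 0.0
        for _ in range(n_iters):
            total += self.step_fn(train)
        mean_loss = total / n_iters
        self.history.append((epoch, "train" if train else "val", mean_loss))
        return mean_loss

    def run_train(self) -> None:
        # Assumption: one training pass per epoch, then an optional validation pass.
        for epoch in range(self.n_epoches):
            self.run_epoch(epoch, train=True)
            if self.has_val:
                self.run_epoch(epoch, train=False)


# Count how many iterations of each kind are executed.
calls = {"train": 0, "val": 0}


def fake_step(train: bool) -> float:
    calls["train" if train else "val"] += 1
    return 1.0


trainer = SketchTrainer(
    fake_step, has_val=True, n_epoches=2, n_iters_per_epoch=3, n_iters_test=4
)
trainer.run_train()
print(calls)  # {'train': 6, 'val': 8}
```

Under these assumptions, the defaults (n_epoches=100, n_iters_per_epoch=500) would amount to 50,000 training iterations in total.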

save(is_best: bool = False)

Save checkpoints.
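
The page does not say what is_best changes. A common convention, shown here purely as an assumption, is to always write the latest checkpoint and to additionally keep a copy when the current model is the best seen so far. The helper save_checkpoint, the file names, the JSON format, and the validation accuracies below are all hypothetical and chosen only for illustration; metallic's own save may store different data in a different format.

```python
import json
import shutil
import tempfile
from pathlib import Path


def save_checkpoint(state: dict, ckpt_dir: Path, is_best: bool = False) -> Path:
    """Write the latest state; if is_best, also copy it to a separate 'best' file."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    last_path = ckpt_dir / "checkpoint_last.json"  # hypothetical file name
    last_path.write_text(json.dumps(state))
    if is_best:
        # Keep a copy that later, worse epochs will not overwrite.
        shutil.copyfile(last_path, ckpt_dir / "checkpoint_best.json")  # hypothetical file name
    return last_path


ckpt_dir = Path(tempfile.mkdtemp())
best_acc = float("-inf")

# Made-up validation accuracies, one per epoch, for illustration only.
for epoch, val_acc in enumerate([0.41, 0.57, 0.52]):
    is_best = val_acc > best_acc
    best_acc = max(best_acc, val_acc)
    save_checkpoint({"epoch": epoch, "val_acc": val_acc}, ckpt_dir, is_best=is_best)

best = json.loads((ckpt_dir / "checkpoint_best.json").read_text())
last = json.loads((ckpt_dir / "checkpoint_last.json").read_text())
print(best["epoch"], last["epoch"])  # 1 2
```

With this convention, the "last" file supports resuming an interrupted run, while the "best" file is the one to load for final evaluation.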