import time
from tqdm import tqdm
from typing import Optional
from torch import nn
from ..data import MetaDataLoader
from ..metalearners import MetaLearner
from ..utils import MetricTracker, Logger
class Trainer:
"""
A wrap of training procedure for meta-learning algorithms.
Parameters
----------
metalearner : MetaLearner
An instance of :class:`~metallic.metalearners.MetaLearner` class
train_loader : MetaDataLoader)
Train data loader, an instance of :class:`~metallic.data.dataloader.MetaDataLoader` class
val_loader : MetaDataLoader, optional
Validation data loader, an instance of :class:`~metallic.data.dataloader.MetaDataLoader` class
n_epoches : int, optional, default=100
Number of epoches
n_iters_per_epoch : int, optional, default=500
Number of the iterations per epoch
n_iters_test : int, optional, default=600
Number of the iterations during meta-test stage
logger : Logger, optional
An instance of :class:`~metallic.utils.logger.Logger` class
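
    Example
    -------
    A minimal usage sketch; the ``maml`` metalearner and the two
    ``MetaDataLoader`` instances below are assumed to be constructed
    elsewhere::

        trainer = Trainer(
            metalearner=maml,
            train_loader=train_loader,
            val_loader=val_loader,
            n_epoches=50
        )
        trainer.run_train()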
"""

    def __init__(
        self,
        metalearner: MetaLearner,
        train_loader: MetaDataLoader,
        val_loader: Optional[MetaDataLoader] = None,
        n_epoches: int = 100,
        n_iters_per_epoch: int = 500,
        n_iters_test: int = 600,
        logger: Optional[Logger] = None
    ):
        self.metalearner = metalearner
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.n_epoches = n_epoches
        self.n_iters_per_epoch = n_iters_per_epoch
        self.n_iters_test = n_iters_test
        self.logger = logger
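        # cache the task configuration (N-way, K-shot) from the training
        # dataset; it is used to name checkpoint files in ``save``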
        self.n_way = self.train_loader.dataset.n_way
        self.k_shot = self.train_loader.dataset.task_splits['support']

    def save(self, is_best: bool = False):
        """Save checkpoints."""
        if self.metalearner.root and self.metalearner.save_basename:
            # checkpoint names follow the "{k_shot}shot_{n_way}way_" convention
            self.metalearner.save('{0}shot_{1}way_'.format(self.k_shot, self.n_way))
            # If this checkpoint is the best so far, store a copy so it
            # doesn't get overwritten by a worse checkpoint.
            if is_best:
                best_path = self.metalearner.save(
                    'best_{0}shot_{1}way_'.format(self.k_shot, self.n_way)
                )
                print('Saved the current best checkpoint to: {0}.'.format(best_path))

    def lr_schedule(self):
        """Schedule learning rate."""
        if self.metalearner.lr_scheduler:
            self.metalearner.lr_schedule()

    def run_epoch(self, epoch: int, train: bool = True):
        """Train or evaluate an epoch."""
        tracker = MetricTracker('batch_time', 'data_time', 'loss', 'accuracy')
        # reset the start time
        start = time.time()

        loader = self.train_loader if train else self.val_loader
        n_iters = self.n_iters_per_epoch if train else self.n_iters_test
        stage = 'meta-train' if train else 'meta-test'

        for i_iter, batch in tqdm(
            enumerate(loader), total=n_iters, desc=f"Epoch [{epoch}] ({stage})"
        ):
            # data loading time per batch
            tracker.update('data_time', time.time() - start)
            # get loss and accuracy
            loss, accuracy = self.metalearner.step(batch, meta_train=train)
            # track average loss and accuracy
            tracker.update('loss', loss)
            tracker.update('accuracy', accuracy)
            # track average forward prop. + back prop. time per batch
            tracker.update('batch_time', time.time() - start)
            # reset the start time
            start = time.time()
            # log training status
            if self.logger is not None:
                self.logger.log(tracker.metrics, epoch, i_iter + 1, stage)
            if (i_iter + 1) >= n_iters:
                break

        return tracker['accuracy'].mean

    def run_train(self):
        """Run training procedure."""
        best_acc = 0.

        for epoch in range(1, self.n_epoches + 1):
            # meta-train an epoch
            recent_acc = self.run_epoch(epoch, train=True)
            # meta-test an epoch, get the average accuracy over all batches
            if self.val_loader is not None:
                recent_acc = self.run_epoch(epoch, train=False)
            # if the current model achieves the best accuracy
            is_best = recent_acc > best_acc
            best_acc = max(recent_acc, best_acc)
            # save checkpoint
            self.save(is_best)
            # schedule learning rate
            self.lr_schedule()