Source code for metallic.metalearners.gbml.base

import os
from abc import ABC, abstractmethod
from typing import Callable, Optional, Tuple
import torch
from torch import nn, optim

from ..base import MetaLearner
from ...utils import get_accuracy

class GBML(MetaLearner, ABC):
    """
    A base class for gradient-based meta-learning algorithms.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped
    in_optim : torch.optim.Optimizer
        Optimizer for the inner loop
    out_optim : torch.optim.Optimizer
        Optimizer for the outer loop
    root : str, optional
        Root directory to save checkpoints
    save_basename : str, optional
        Base name of the saved checkpoints
    lr_scheduler : callable, optional
        Learning rate scheduler
    loss_function : callable, optional
        Loss function
    inner_steps : int, optional, default=1
        Number of gradient descent updates in the inner loop
    device : torch.device, optional
        Device on which the model is defined. If ``None``, the device will
        be detected automatically.
    """

    def __init__(
        self,
        model: nn.Module,
        in_optim: optim.Optimizer,
        out_optim: optim.Optimizer,
        root: Optional[str] = None,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        inner_steps: int = 1,
        device: Optional[torch.device] = None
    ) -> None:
        super(GBML, self).__init__(
            model = model,
            root = root,
            save_basename = save_basename,
            lr_scheduler = lr_scheduler,
            loss_function = loss_function,
            device = device
        )

        self.model.train()
        self.in_optim = in_optim
        self.out_optim = out_optim
        self.inner_steps = inner_steps
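
    # ----------------------------------------------------------------------
    # Usage sketch (illustrative, not part of the library): ``GBML`` is
    # abstract, so it is constructed through a concrete subclass. The
    # ``MAML`` import path below is an assumption for illustration only.
    #
    #     from torch import nn, optim
    #     from metallic.metalearners import MAML  # hypothetical import path
    #
    #     model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 5))
    #     learner = MAML(
    #         model = model,
    #         in_optim = optim.SGD(model.parameters(), lr=0.4),
    #         out_optim = optim.Adam(model.parameters(), lr=1e-3),
    #         root = 'checkpoints',
    #         save_basename = 'maml_example',
    #         inner_steps = 5
    #     )
    # ----------------------------------------------------------------------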

    @torch.enable_grad()
    def inner_loop(self, fmodel, diffopt, train_input, train_target) -> None:
        """Inner loop update."""
        for step in range(self.inner_steps):
            # compute loss on the support set
            train_output = fmodel(train_input)
            support_loss = self.loss_function(train_output, train_target)
            # update parameters
            diffopt.step(support_loss)
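
    # ----------------------------------------------------------------------
    # Sketch (illustrative): the ``fmodel`` / ``diffopt`` pair suggests a
    # differentiable optimizer such as the one provided by the ``higher``
    # package. Assuming that, a subclass would typically drive
    # ``inner_loop`` like this inside its outer loop:
    #
    #     import higher
    #
    #     with higher.innerloop_ctx(
    #         self.model, self.in_optim, copy_initial_weights=False
    #     ) as (fmodel, diffopt):
    #         self.inner_loop(fmodel, diffopt, support_input, support_target)
    #         query_output = fmodel(query_input)
    #         query_loss = self.loss_function(query_output, query_target)
    #         query_loss.backward()  # grads flow back to self.model's params
    # ----------------------------------------------------------------------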

    @classmethod
    def load(cls, model_path: str, **kwargs):
        """Load a trained model."""
        state = torch.load(model_path)

        # load model and optimizers
        kwargs['model'] = state['model']
        kwargs['in_optim'] = state['in_optim']
        kwargs['out_optim'] = state['out_optim']

        # model name and save path
        if 'root' not in kwargs:
            kwargs['root'] = os.path.dirname(model_path)
        if 'save_basename' not in kwargs:
            kwargs['save_basename'] = os.path.basename(model_path)

        return cls(**kwargs)

    def save(self, prefix: Optional[str] = None) -> str:
        """Save the trained model."""
        if self.root is None or self.save_basename is None:
            raise RuntimeError('The root directory or save basename of the '
                               'checkpoints is not defined.')

        state = {
            'model': self.model,
            'in_optim': self.in_optim,
            'out_optim': self.out_optim
        }

        name = self.save_basename
        if prefix is not None:
            name = prefix + name
        name += '.pth.tar'

        path = os.path.join(self.root, name)
        torch.save(state, path)
        return path
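
    # ----------------------------------------------------------------------
    # Sketch (illustrative): a save/load round trip under the conventions
    # above. ``learner`` stands for any concrete ``GBML`` subclass instance;
    # the prefix string is an arbitrary example.
    #
    #     path = learner.save(prefix='epoch3_')  # root/epoch3_<basename>.pth.tar
    #     restored = type(learner).load(path)    # rebuilds model + both optimizers
    # ----------------------------------------------------------------------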

    @abstractmethod
    def compute_outer_grads(
        self,
        batch: Tuple[torch.Tensor, ...],
        n_tasks: int,
        meta_train: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute gradients on the query set."""
        pass
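
    # ----------------------------------------------------------------------
    # Sketch (illustrative, not the library's actual subclass): a MAML-style
    # implementation of ``compute_outer_grads``, assuming the ``higher``
    # package and the 4-tuple task format unpacked in ``step``. The per-task
    # loss is scaled by ``1 / n_tasks`` so accumulated gradients average over
    # the task batch.
    #
    #     def compute_outer_grads(self, batch, n_tasks, meta_train=True):
    #         support_input, support_target, query_input, query_target = batch
    #         with higher.innerloop_ctx(
    #             self.model, self.in_optim,
    #             copy_initial_weights=False, track_higher_grads=meta_train
    #         ) as (fmodel, diffopt):
    #             self.inner_loop(fmodel, diffopt, support_input, support_target)
    #             query_output = fmodel(query_input)
    #             query_loss = self.loss_function(query_output, query_target)
    #             if meta_train:
    #                 (query_loss / n_tasks).backward()
    #         return query_output, query_loss
    # ----------------------------------------------------------------------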

    def clear_before_outer_loop(self):
        """Initialization before each outer loop, if needed."""
        pass

    def outer_loop_update(self):
        """Update the model's meta-parameters to optimize the query loss."""
        self.out_optim.step()

    def step(self, batch: dict, meta_train: bool = True) -> Tuple[float, float]:
        """Outer loop."""
        # clear gradients from the last batch
        self.out_optim.zero_grad()

        # loss and accuracy on the query set (outer loop)
        outer_loss, outer_accuracy = 0., 0.

        self.clear_before_outer_loop()

        # get task batch
        task_batch, n_tasks = self.get_tasks(batch)

        for task in task_batch:
            # compute outer loop gradients
            query_output, query_loss = self.compute_outer_grads(task, n_tasks, meta_train)

            # compute accuracy on the query set
            _, _, _, query_target = task
            query_accuracy = get_accuracy(query_output, query_target)

            # record loss and accuracy
            outer_loss += query_loss.detach().item()
            outer_accuracy += query_accuracy.item()

        # When in the meta-training stage, update the model's meta-parameters to
        # optimize the query loss across all of the tasks sampled in this batch.
        if meta_train:
            self.outer_loop_update()

        # average loss and accuracy
        outer_loss /= n_tasks
        outer_accuracy /= n_tasks

        return outer_loss, outer_accuracy
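
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a minimal meta-training loop over ``step``.
# ``learner`` is any concrete ``GBML`` subclass; ``train_loader`` and
# ``val_batch`` are assumed to yield task batches in the format expected by
# ``get_tasks``.
#
#     for epoch in range(num_epochs):
#         for batch in train_loader:
#             loss, accuracy = learner.step(batch, meta_train=True)
#         val_loss, val_accuracy = learner.step(val_batch, meta_train=False)
#         learner.save(prefix=f'epoch{epoch}_')
# ---------------------------------------------------------------------------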