Source code for metallic.metalearners.gbml.maml

from typing import Callable, Optional, Tuple
import higher
import torch
from torch import nn, optim

from .base import GBML
from ...functional import apply_grads, accum_grads

class MAML(GBML):
    """
    Implementation of the Model-Agnostic Meta-Learning (MAML) algorithm
    proposed in [1]. `Here <https://github.com/cbfinn/maml>`_ is the official
    implementation of MAML based on TensorFlow.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped

    in_optim : torch.optim.Optimizer
        Optimizer for the inner loop

    out_optim : torch.optim.Optimizer
        Optimizer for the outer loop

    root : str
        Root directory to save checkpoints

    save_basename : str, optional
        Base name of the saved checkpoints

    lr_scheduler : callable, optional
        Learning rate scheduler

    loss_function : callable, optional
        Loss function

    inner_steps : int, optional, default=1
        Number of gradient descent updates in the inner loop

    first_order : bool, optional, default=False
        Whether to use the first-order approximation of MAML (FOMAML)

    device : optional
        Device on which the model is defined

    .. admonition:: References

        1. "`Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. \
           <https://arxiv.org/abs/1703.03400>`_" Chelsea Finn, et al. ICML 2017.
    """

    alg_name = 'MAML'

    def __init__(
        self,
        model: nn.Module,
        in_optim: optim.Optimizer,
        out_optim: optim.Optimizer,
        root: str,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        inner_steps: int = 1,
        first_order: bool = False,
        device: Optional = None
    ) -> None:
        if save_basename is None:
            save_basename = self.alg_name

        super(MAML, self).__init__(
            model = model,
            in_optim = in_optim,
            out_optim = out_optim,
            root = root,
            save_basename = save_basename,
            lr_scheduler = lr_scheduler,
            loss_function = loss_function,
            inner_steps = inner_steps,
            device = device
        )

        self.first_order = first_order
        # per-task outer-loop gradients accumulated over one meta-batch
        self.grad_list = []
    def clear_before_outer_loop(self):
        """Reset the accumulated per-task gradients before each outer loop."""
        self.grad_list = []
    def outer_loop_update(self):
        """Update the model's meta-parameters to optimize the query loss."""
        # apply the accumulated gradients to the original model parameters
        apply_grads(self.model, accum_grads(self.grad_list))
        # outer loop update
        self.out_optim.step()
    def compute_outer_grads(
        self,
        task: Tuple[torch.Tensor, ...],
        n_tasks: int,
        meta_train: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute gradients on the query set."""
        support_input, support_target, query_input, query_target = task

        # Use higher to make the model stateless and the optimizer
        # differentiable, so that copies of the model's parameters are kept
        # automatically as they are updated in the inner loop.
        with higher.innerloop_ctx(
            self.model,
            self.in_optim,
            track_higher_grads = (meta_train and (not self.first_order))
        ) as (fmodel, diffopt):
            # fmodel: stateless version of the model
            # diffopt: differentiable version of the optimizer

            # inner loop (adapt to the support set)
            self.inner_loop(fmodel, diffopt, support_input, support_target)

            # evaluate on the query set
            with torch.set_grad_enabled(meta_train):
                query_output = fmodel(query_input)
                query_loss = self.loss_function(query_output, query_target)
                query_loss /= len(query_input)

            # compute gradients w.r.t. the initial parameters when in the
            # meta-training stage
            if meta_train:
                outer_grad = torch.autograd.grad(
                    query_loss / n_tasks, fmodel.parameters(time=0)
                )
                self.grad_list.append(outer_grad)

        return query_output, query_loss
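
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It runs
# one meta-training step using only the methods defined above
# (``clear_before_outer_loop``, ``compute_outer_grads``, ``outer_loop_update``).
# The toy model, the randomly generated tasks, the checkpoint directory and
# the loss function are assumptions made for the example; the actual training
# loop in metallic may be provided by a separate trainer.
#
#     import torch
#     from torch import nn, optim
#     from metallic.metalearners.gbml.maml import MAML
#
#     model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 5))
#     maml = MAML(
#         model = model,
#         in_optim = optim.SGD(model.parameters(), lr=0.01),
#         out_optim = optim.Adam(model.parameters(), lr=0.001),
#         root = './checkpoints',                                 # assumed path
#         loss_function = nn.CrossEntropyLoss(reduction='sum'),   # assumed loss
#         inner_steps = 5,
#         first_order = False
#     )
#
#     n_tasks = 4  # meta-batch size
#     maml.clear_before_outer_loop()
#     for _ in range(n_tasks):
#         # each task: (support_input, support_target, query_input, query_target)
#         task = (torch.randn(25, 4), torch.randint(0, 5, (25,)),
#                 torch.randn(25, 4), torch.randint(0, 5, (25,)))
#         query_output, query_loss = maml.compute_outer_grads(task, n_tasks)
#     maml.outer_loop_update()  # apply accumulated gradients, step out_optim
# ---------------------------------------------------------------------------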