Source code for metallic.metalearners.gbml.anil

from typing import Callable, Optional, Tuple
import higher
import torch
from torch import nn, optim

from .base import GBML
from ...functional import apply_grads, accum_grads

class ANIL(GBML):
    """
    Implementation of Almost No Inner Loop (ANIL) algorithm proposed in [1],
    which only updates the head of the neural network in the inner loop.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped. It must expose an ``encoder`` attribute and a
        ``classifier`` attribute (the latter is used as the head).

    in_optim : torch.optim.Optimizer
        Optimizer for the inner loop

    out_optim : torch.optim.Optimizer
        Optimizer for the outer loop

    root : str
        Root directory to save checkpoints

    save_basename : str, optional
        Base name of the saved checkpoints. Falls back to ``alg_name``
        when not given.

    lr_scheduler : callable, optional
        Learning rate scheduler

    loss_function : callable, optional
        Loss function

    inner_steps : int, optional, default=1
        Number of gradient descent updates in inner loop

    device : optional
        Device on which the model is defined

    .. admonition:: References

        1. "`Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness \
           of MAML. <https://arxiv.org/abs/1909.09157>`_" Aniruddh Raghu, et al. ICLR 2020.
    """

    alg_name = 'ANIL'

    def __init__(
        self,
        model: nn.Module,
        in_optim: optim.Optimizer,
        out_optim: optim.Optimizer,
        root: str,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        inner_steps: int = 1,
        device: Optional = None
    ) -> None:
        # Default the checkpoint base name to the algorithm name.
        if save_basename is None:
            save_basename = self.alg_name

        super().__init__(
            model=model,
            in_optim=in_optim,
            out_optim=out_optim,
            root=root,
            save_basename=save_basename,
            lr_scheduler=lr_scheduler,
            loss_function=loss_function,
            inner_steps=inner_steps,
            device=device
        )

        # Keep separate handles on the two parts of the model: only the head
        # is handed to the inner-loop context in ``compute_outer_grads``.
        self.encoder = model.encoder
        self.head = model.classifier

    def compute_outer_grads(
        self,
        task: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        n_tasks: int,
        meta_train: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute gradients on query set.

        Parameters
        ----------
        task : tuple
            A 4-tuple ``(support_input, support_target, query_input,
            query_target)``.

        n_tasks : int
            The query loss is divided by this value before ``backward()``
            is called.

        meta_train : bool, optional, default=True
            When true, higher-order gradients are tracked through the inner
            loop, the query pass runs with gradients enabled, and
            ``backward()`` is called on the scaled query loss. When false,
            no backward pass is performed.

        Returns
        -------
        query_output : torch.Tensor
            Output of the adapted head on the query features.

        query_loss : torch.Tensor
            Loss on the query set divided by ``len(query_input)`` (not
            divided by ``n_tasks``).
        """
        support_input, support_target, query_input, query_target = task

        with higher.innerloop_ctx(
            self.head, self.in_optim,
            copy_initial_weights=False,
            track_higher_grads=meta_train
        ) as (fhead, diffopt):
            # Support features are computed without gradient tracking, so
            # the adaptation steps only involve the head.
            with torch.no_grad():
                support_feature = self.encoder(support_input)

            # inner loop (adapt)
            self.inner_loop(fhead, diffopt, support_feature, support_target)

            # evaluate on the query set
            with torch.set_grad_enabled(meta_train):
                query_feature = self.encoder(query_input)
                query_output = fhead(query_feature)
                query_loss = self.loss_function(query_output, query_target)
                query_loss /= len(query_input)

            # compute gradients when in the meta-training stage
            if meta_train:
                (query_loss / n_tasks).backward()

        return query_output, query_loss