Source code for metallic.metalearners.gbml.minibatchprox

from typing import Callable, Optional
import torch
from torch import nn, optim

from .reptile import Reptile
from ...functional import ProximalRegLoss

class MinibatchProx(Reptile):
    """
    Implementation of the MinibatchProx algorithm proposed in [1].

    `Here <https://panzhous.github.io/assets/code/MetaMinibatchProx.zip>`_ is
    the official implementation of MinibatchProx based on TensorFlow.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped
    in_optim : torch.optim.Optimizer
        Optimizer for the inner loop
    out_optim : torch.optim.Optimizer
        Optimizer for the outer loop
    root : str
        Root directory to save checkpoints
    save_basename : str, optional
        Base name of the saved checkpoints
    lr_scheduler : callable, optional
        Learning rate scheduler
    loss_function : callable, optional
        Loss function to be wrapped in
        :func:`metallic.functional.ProximalRegLoss()`
    inner_steps : int, optional, default=5
        Number of gradient descent updates in the inner loop
    lamb : float, optional, default=0.1
        Coefficient of the proximal regularization term; the outer-loop
        learning rate is also scaled by this value
    device : optional
        Device on which the model is defined

    .. admonition:: References

        1. "`Efficient Meta Learning via Minibatch Proximal Update. \
           <https://panzhous.github.io/assets/pdf/2019-NIPS-metaleanring.pdf>`_" \
           Pan Zhou, et al. NeurIPS 2019.
    """

    alg_name = 'MinibatchProx'

    def __init__(
        self,
        model: nn.Module,
        in_optim: optim.Optimizer,
        out_optim: optim.Optimizer,
        root: str,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        inner_steps: int = 5,
        lamb: float = 0.1,
        device: Optional = None
    ) -> None:
        if save_basename is None:
            save_basename = self.alg_name

        # Scale the outer-loop learning rate by the proximal coefficient lamb.
        for g in out_optim.param_groups:
            g['lr'] = g['lr'] * lamb

        super(MinibatchProx, self).__init__(
            model=model,
            in_optim=in_optim,
            out_optim=out_optim,
            root=root,
            save_basename=save_basename,
            lr_scheduler=lr_scheduler,
            loss_function=loss_function,
            inner_steps=inner_steps,
            device=device
        )

        # Inner-loop objective: task loss plus a proximal penalty that keeps
        # the adapted parameters close to the meta-initialization.
        self.reg_loss_function = ProximalRegLoss(self.loss_function, lamb)
    @torch.enable_grad()
    def inner_loop(self, fmodel, diffopt, train_input, train_target) -> None:
        """Inner loop update."""
        # Record the meta-parameters (the initialization before adaptation),
        # which the proximal term regularizes towards.
        init_params = [
            p.detach().clone().requires_grad_(True)
            for p in fmodel.parameters()
        ]

        for step in range(self.inner_steps):
            train_output = fmodel(train_input)
            # Current (adapted) model parameters.
            params = list(fmodel.parameters())
            # Task loss on the support set plus the proximal regularization
            # between the current parameters and the meta-parameters.
            support_loss = self.reg_loss_function(
                train_output, train_target, init_params, params
            )
            diffopt.step(support_loss)
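
# --- Reference sketch (not part of the original module) ---------------------
# ProximalRegLoss is imported from ``metallic.functional`` above; its exact
# implementation is not shown on this page.  The helper below is only an
# assumption about the quantity such a proximal regularized loss computes,
# following the MinibatchProx inner objective:
#
#     task_loss(output, target) + (lamb / 2) * || params - init_params ||^2
#
# The function name and signature are hypothetical and purely illustrative.
def _proximal_reg_loss_sketch(loss_function, lamb, output, target,
                              init_params, params):
    # Ordinary task loss on the support set.
    base_loss = loss_function(output, target)
    # Squared L2 distance between the adapted parameters and the
    # meta-initialization, summed over all parameter tensors.
    prox_term = sum(
        ((p - p0) ** 2).sum() for p, p0 in zip(params, init_params)
    )
    return base_loss + 0.5 * lamb * prox_term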
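
# --- Usage sketch (not part of the original module) -------------------------
# A minimal example of constructing a MinibatchProx meta-learner.  The toy
# model, optimizer settings, checkpoint directory, and hyperparameter values
# below are placeholder assumptions, not values prescribed by the library;
# the constructor arguments themselves follow the signature defined above.
if __name__ == '__main__':
    toy_model = nn.Sequential(
        nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 5)
    )
    in_optim = optim.SGD(toy_model.parameters(), lr=1e-2)   # inner-loop (adaptation) optimizer
    out_optim = optim.SGD(toy_model.parameters(), lr=1e-3)  # outer-loop (meta-update) optimizer

    metalearner = MinibatchProx(
        model=toy_model,
        in_optim=in_optim,
        out_optim=out_optim,
        root='./checkpoints',   # directory where checkpoints would be saved
        inner_steps=5,
        lamb=0.1,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    )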