# Source code for metallic.metalearners.gbml.fomaml

from typing import Any, Callable, Optional

from torch import nn, optim

from .maml import MAML

class FOMAML(MAML):
    """
    Implementation of First-Order Model-Agnostic Meta-Learning (FOMAML)
    algorithm proposed in [1].

    In FOMAML, the second derivatives in the outer loop are omitted, which
    means the gradients are directly computed on the fast parameters.

    `Here <https://github.com/cbfinn/maml>`_ is the official implementation
    of MAML based on Tensorflow.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped

    in_optim : torch.optim.Optimizer
        Optimizer for the inner loop

    out_optim : torch.optim.Optimizer
        Optimizer for the outer loop

    root : str
        Root directory to save checkpoints

    save_basename : str, optional
        Base name of the saved checkpoints

    lr_scheduler : callable, optional
        Learning rate scheduler

    loss_function : callable, optional
        Loss function

    inner_steps : int, optional, default=1
        Number of gradient descent updates in inner loop

    device : optional
        Device on which the model is defined

    .. admonition:: References

        1. "`Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. \
           <https://arxiv.org/abs/1703.03400>`_" Chelsea Finn, et al. ICML 2017.
    """

    # Algorithm identifier exposed on the class.
    alg_name = 'FOMAML'

    def __init__(
        self,
        model: nn.Module,
        in_optim: optim.Optimizer,
        out_optim: optim.Optimizer,
        root: str,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        inner_steps: int = 1,
        device: Optional[Any] = None
    ) -> None:
        # FOMAML is MAML constructed with ``first_order=True``; every other
        # argument is forwarded unchanged to the MAML constructor.
        super().__init__(
            model=model,
            in_optim=in_optim,
            out_optim=out_optim,
            root=root,
            save_basename=save_basename,
            lr_scheduler=lr_scheduler,
            loss_function=loss_function,
            inner_steps=inner_steps,
            first_order=True,
            device=device,
        )