metallic.metalearners

class metallic.metalearners.MetaLearner(model: torch.nn.modules.module.Module, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, device: Optional = None)[source]

Bases: abc.ABC

A base class for all meta-learning algorithms.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • device (optional) – Device on which the model is defined. If None, device will be detected automatically.

get_tasks(batch: dict) → tuple[source]
abstract classmethod load(model_path: str, **kwargs)[source]

Load a trained model.

lr_schedule() → None[source]

Schedule learning rate.

abstract save(prefix: Optional[str] = None) → str[source]

Save the trained model.

abstract step(batch: dict, meta_train: bool = True) → Tuple[float][source]

Gradient-Based Meta-Learning

class metallic.metalearners.GBML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]

Bases: metallic.metalearners.base.MetaLearner, abc.ABC

A base class for gradient-based meta-learning algorithms.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop

  • device (optional) – Device on which the model is defined. If None, device will be detected automatically.

clear_before_outer_loop()[source]

Initialization before each outer loop if needed.

abstract compute_outer_grads(batch: Tuple[torch.Tensor], meta_train: bool = True) → Tuple[torch.Tensor, torch.Tensor][source]

Compute gradients on query set.

inner_loop(fmodel, diffopt, train_input, train_target) → None[source]

Inner loop update.

classmethod load(model_path: str, **kwargs)[source]

Load a trained model.

outer_loop_update()[source]

Update the model’s meta-parameters to optimize the query loss.

save(prefix: Optional[str] = None) → str[source]

Save the trained model.

step(batch: dict, meta_train: bool = True) → Tuple[float][source]

Outer loop

class metallic.metalearners.MAML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, first_order: bool = False, device: Optional = None)[source]

Bases: metallic.metalearners.gbml.base.GBML

Implementation of the Model-Agnostic Meta-Learning (MAML) algorithm proposed in [1].

Here is the official implementation of MAML based on TensorFlow.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop

  • first_order (bool, optional, default=False) – Whether to use the first-order approximation of MAML (FOMAML)

  • device (optional) – Device on which the model is defined

References

  1. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” Chelsea Finn, et al. ICML 2017.

alg_name = 'MAML'
clear_before_outer_loop()[source]

Initialization before each outer loop if needed.

compute_outer_grads(task: Tuple[torch.Tensor], n_tasks: int, meta_train: bool = True) → Tuple[torch.Tensor, torch.Tensor][source]

Compute gradients on query set.

outer_loop_update()[source]

Update the model’s meta-parameters to optimize the query loss.
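
Example (a minimal usage sketch, not taken from the library's documentation; the layout of the task batch, the tensor shapes, and the model architecture are assumptions modeled on torchmeta-style task loaders):

    import torch
    from torch import nn, optim
    from metallic.metalearners import MAML

    # Placeholder 5-way classifier over 784-dimensional inputs
    model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 5))

    maml = MAML(
        model=model,
        in_optim=optim.SGD(model.parameters(), lr=0.4),    # inner-loop optimizer
        out_optim=optim.Adam(model.parameters(), lr=1e-3),  # outer-loop optimizer
        root='checkpoints',
        loss_function=nn.CrossEntropyLoss(),
        inner_steps=1,
        first_order=False,
    )

    # Assumed batch layout: support ('train') and query ('test') splits for a
    # meta-batch of 4 tasks; check your task sampler for the exact format.
    batch = {
        'train': (torch.randn(4, 5, 784), torch.randint(0, 5, (4, 5))),
        'test': (torch.randn(4, 15, 784), torch.randint(0, 5, (4, 15))),
    }
    results = maml.step(batch, meta_train=True)   # returns a Tuple[float]

    ckpt_path = maml.save(prefix='demo')   # save() returns the checkpoint path
    restored = MAML.load(ckpt_path)        # classmethod: restore a trained learner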

class metallic.metalearners.FOMAML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]

Bases: metallic.metalearners.gbml.maml.MAML

Implementation of the First-Order Model-Agnostic Meta-Learning (FOMAML) algorithm proposed in [1]. In FOMAML, the second-order derivatives in the outer loop are omitted, so the meta-gradients are computed directly with respect to the fast (task-adapted) parameters.

Here is the official implementation of MAML based on TensorFlow.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop

  • device (optional) – Device on which the model is defined

References

  1. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” Chelsea Finn, et al. ICML 2017.

alg_name = 'FOMAML'
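
Example (a minimal construction sketch; per the class description, FOMAML drops the second-order terms of MAML, so it is presumably equivalent to constructing MAML with first_order=True; the model and optimizers are placeholders):

    from torch import nn, optim
    from metallic.metalearners import FOMAML

    model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 5))
    fomaml = FOMAML(
        model=model,
        in_optim=optim.SGD(model.parameters(), lr=0.4),
        out_optim=optim.Adam(model.parameters(), lr=1e-3),
        root='checkpoints',
        loss_function=nn.CrossEntropyLoss(),
        inner_steps=1,
    )
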
class metallic.metalearners.Reptile(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 5, device: Optional = None)[source]

Bases: metallic.metalearners.gbml.base.GBML

Implementation of the Reptile algorithm proposed in [1].

Here is the official implementation of Reptile based on TensorFlow.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • inner_steps (int, optional, default=5) – Number of gradient descent updates in the inner loop

  • device (optional) – Device on which the model is defined

References

  1. “On First-Order Meta-Learning Algorithms.” Alex Nichol, et al. arXiv 2018.

alg_name = 'Reptile'
clear_before_outer_loop()[source]

Initialization before each outer loop if needed.

compute_outer_grads(task: Tuple[torch.Tensor], n_tasks: int, meta_train: bool = True) → Tuple[torch.Tensor, torch.Tensor][source]

Compute gradients on query set.

outer_loop_update()[source]

Update the model’s meta-parameters to optimize the query loss.
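
Example (a sketch of the Reptile meta-update rule from [1], shown for intuition only; it is not necessarily how outer_loop_update is implemented here):

    import torch

    def reptile_update(meta_params, fast_params, epsilon=0.1):
        # Move each meta-parameter theta a small step toward the task-adapted
        # (fast) parameter phi produced by the inner loop:
        #     theta <- theta + epsilon * (phi - theta)
        with torch.no_grad():
            for theta, phi in zip(meta_params, fast_params):
                theta.add_(epsilon * (phi - theta))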

class metallic.metalearners.MinibatchProx(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 5, lamb: float = 0.1, device: Optional = None)[source]

Bases: metallic.metalearners.gbml.reptile.Reptile

Implementation of the MinibatchProx algorithm proposed in [1].

Here is the official implementation of MinibatchProx based on TensorFlow.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function to be wrapped in metallic.functional.ProximalRegLoss()

  • inner_steps (int, optional, default=5) – Number of gradient descent updates in the inner loop

  • lamb (float, optional, default=0.1) – Strength of the proximal regularization term (λ in [1])

  • device (optional) – Device on which the model is defined

References

  1. “Efficient Meta Learning via Minibatch Proximal Update.” Pan Zhou, et al. NeurIPS 2019.

alg_name = 'MinibatchProx'
inner_loop(fmodel, diffopt, train_input, train_target) → None[source]

Inner loop update.
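
Example (a construction sketch; the parameter list above says the loss is wrapped in metallic.functional.ProximalRegLoss(), so a plain task loss is passed here, and lamb is assumed to be the strength of the proximal term; the model and optimizers are placeholders):

    from torch import nn, optim
    from metallic.metalearners import MinibatchProx

    model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 5))
    mbprox = MinibatchProx(
        model=model,
        in_optim=optim.SGD(model.parameters(), lr=0.1),
        out_optim=optim.SGD(model.parameters(), lr=1e-3),
        root='checkpoints',
        loss_function=nn.CrossEntropyLoss(),
        inner_steps=5,
        lamb=0.1,
    )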

class metallic.metalearners.ANIL(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]

Bases: metallic.metalearners.gbml.base.GBML

Implementation of the Almost No Inner Loop (ANIL) algorithm proposed in [1], which only updates the head of the neural network in the inner loop.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • in_optim (torch.optim.Optimizer) – Optimizer for the inner loop

  • out_optim (torch.optim.Optimizer) – Optimizer for the outer loop

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop

  • device (optional) – Device on which the model is defined

alg_name = 'ANIL'
compute_outer_grads(task: Tuple[torch.Tensor], n_tasks: int, meta_train: bool = True) → Tuple[torch.Tensor, torch.Tensor][source]

Compute gradients on query set.
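
Example (a conceptual sketch of the head-only inner loop described above, for illustration only; it does not show how this class identifies or adapts the head internally):

    import torch
    from torch import nn
    import torch.nn.functional as F

    body = nn.Sequential(nn.Linear(784, 64), nn.ReLU())  # feature extractor
    head = nn.Linear(64, 5)                               # classifier head

    # The inner optimizer only sees the head's parameters, so a gradient step
    # on the support loss adapts the head while the body stays fixed.
    inner_opt = torch.optim.SGD(head.parameters(), lr=0.4)

    support_x = torch.randn(25, 784)
    support_y = torch.randint(0, 5, (25,))
    loss = F.cross_entropy(head(body(support_x)), support_y)
    inner_opt.zero_grad()
    loss.backward()
    inner_opt.step()   # only the head changes; the body is updated in the outer loop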

Metric-Based Meta-Learning

class metallic.metalearners.ProtoNet(model: torch.nn.modules.module.Module, optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, distance: str = 'euclidean', device: Optional = None)[source]

Bases: metallic.metalearners.mbml.base.MBML

Implementation of Prototypical Networks proposed in [1].

Here is the official implementation of Prototypical Networks based on PyTorch.

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • optim (torch.optim.Optimizer) – Optimizer

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • distance (str, optional, default='euclidean') – Type of distance function to be used for computing similarity

  • device (optional) – Device on which the model is defined. If None, device will be detected automatically.

References

  1. “Prototypical Networks for Few-shot Learning.” Jake Snell, et al. NIPS 2017.

alg_name = 'ProtoNet'
single_task(task: Tuple[torch.Tensor], meta_train: bool = True) → Tuple[float][source]
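
Example (a construction sketch; the wrapped model is assumed to be an embedding network, and its architecture and the hyperparameters are placeholders):

    import torch
    from torch import nn
    from metallic.metalearners import ProtoNet

    embedder = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 64))
    protonet = ProtoNet(
        model=embedder,
        optim=torch.optim.Adam(embedder.parameters(), lr=1e-3),
        root='checkpoints',
        loss_function=nn.CrossEntropyLoss(),
        distance='euclidean',
    )
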
class metallic.metalearners.MatchNet(model: torch.nn.modules.module.Module, optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, distance: str = 'cosine', device: Optional = None)[source]

Bases: metallic.metalearners.mbml.base.MBML

Implementation of Matching Networks proposed in [1].

Parameters
  • model (torch.nn.Module) – Model to be wrapped

  • optim (torch.optim.Optimizer) – Optimizer

  • root (str) – Root directory to save checkpoints

  • save_basename (str, optional) – Base name of the saved checkpoints

  • lr_scheduler (callable, optional) – Learning rate scheduler

  • loss_function (callable, optional) – Loss function

  • distance (str, optional, default='cosine') – Type of distance function to be used for computing similarity

  • device (optional) – Device on which the model is defined. If None, device will be detected automatically.

References

  1. “Matching Networks for One Shot Learning.” Oriol Vinyals, et al. NIPS 2016.

alg_name = 'MatchNet'
attention(distance: torch.FloatTensor, targets: torch.LongTensor) → torch.FloatTensor[source]

An attention kernel that serves as a classifier. It defines a probability distribution over output labels given a query example.

The classifier output is a weighted sum of the support-point labels, with weights proportional to the similarity between the support and query embeddings.

Parameters
  • distance (torch.FloatTensor) – Similarity between support-point embeddings and query-point embeddings, with shape (n_samples_query, n_samples_support)

  • targets (torch.LongTensor) – Targets of the support points, with shape (n_samples_support,)

Returns

pred_pdf – Probability distribution over output labels, with shape (n_samples_query, n_way)

Return type

torch.FloatTensor
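
Example (a minimal sketch of the attention-kernel computation described above, for illustration only and not this library's implementation; the softmax over the support dimension and the explicit n_way argument are assumptions):

    import torch
    import torch.nn.functional as F

    def attention_sketch(distance: torch.FloatTensor,
                         targets: torch.LongTensor,
                         n_way: int) -> torch.FloatTensor:
        # distance: (n_samples_query, n_samples_support) similarity scores
        weights = F.softmax(distance, dim=1)           # attention over support points
        one_hot = F.one_hot(targets, n_way).float()    # (n_samples_support, n_way)
        return weights @ one_hot                       # (n_samples_query, n_way)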

single_task(task: Tuple[torch.Tensor], meta_train: bool = True) → Tuple[float][source]