metallic.metalearners
class metallic.metalearners.MetaLearner(model: torch.nn.modules.module.Module, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, device: Optional = None)[source]
Bases: abc.ABC
A base class for all meta-learning algorithms.
- Parameters
model (torch.nn.Module) – Model to be wrapped
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
device (optional) – Device on which the model is defined. If None, device will be detected automatically.
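For orientation, here is a minimal, purely illustrative sketch of the objects a MetaLearner subclass consumes; the toy convolutional model below is standard PyTorch and is not part of metallic:

    import torch.nn as nn

    # Toy 5-way few-shot classifier to wrap; any nn.Module works.
    model = nn.Sequential(
        nn.Conv2d(1, 64, 3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, 5),  # classification head
    )
    loss_function = nn.CrossEntropyLoss()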
Gradient-Based Meta-Learning
class metallic.metalearners.GBML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]
Bases: metallic.metalearners.base.MetaLearner, abc.ABC
A base class for gradient-based meta-learning algorithms.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop
device (optional) – Device on which the model is defined. If None, device will be detected automatically.
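A common pattern in the gradient-based meta-learning literature, consistent with the signature above, is plain SGD for the inner loop and Adam for the outer loop; a sketch, reusing the toy model from above:

    import torch

    in_optim = torch.optim.SGD(model.parameters(), lr=0.01)     # fast adaptation
    out_optim = torch.optim.Adam(model.parameters(), lr=0.001)  # meta-update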
class metallic.metalearners.MAML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, first_order: bool = False, device: Optional = None)[source]
Bases: metallic.metalearners.gbml.base.GBML
Implementation of the Model-Agnostic Meta-Learning (MAML) algorithm proposed in [1].
The official TensorFlow implementation of MAML can be found here.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop
first_order (bool, optional, default=False) – Use the first-order approximation of MAML (FOMAML) or not
device (optional) – Device on which the model is defined
References
[1] “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” Chelsea Finn, et al. ICML 2017.
alg_name = 'MAML'
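Putting the pieces together, a MAML learner can be constructed directly from the documented signature; the training-loop methods are not documented on this page, so only construction is sketched:

    from metallic.metalearners import MAML

    maml = MAML(
        model=model,
        in_optim=in_optim,
        out_optim=out_optim,
        root='checkpoints',   # directory for saved checkpoints
        inner_steps=5,        # gradient steps per task in the inner loop
        first_order=False,    # keep the full second-order MAML update
    )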
class metallic.metalearners.FOMAML(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]
Bases: metallic.metalearners.gbml.maml.MAML
Implementation of the First-Order Model-Agnostic Meta-Learning (FOMAML) algorithm proposed in [1]. In FOMAML, the second-order derivatives in the outer loop are omitted, which means the outer-loop gradients are computed directly on the fast parameters.
The official TensorFlow implementation of MAML can be found here.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop
device (optional) – Device on which the model is defined
References
[1] “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” Chelsea Finn, et al. ICML 2017.
alg_name = 'FOMAML'
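Since FOMAML subclasses MAML, and MAML exposes a first_order flag, constructing FOMAML should be equivalent in effect to constructing MAML with first_order=True:

    from metallic.metalearners import MAML, FOMAML

    fomaml = FOMAML(model=model, in_optim=in_optim, out_optim=out_optim,
                    root='checkpoints')
    # Equivalent in effect to:
    maml_fo = MAML(model=model, in_optim=in_optim, out_optim=out_optim,
                   root='checkpoints', first_order=True)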
class metallic.metalearners.Reptile(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 5, device: Optional = None)[source]
Bases: metallic.metalearners.gbml.base.GBML
Implementation of the Reptile algorithm proposed in [1].
The official TensorFlow implementation of Reptile can be found here.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
inner_steps (int, optional, default=5) – Number of gradient descent updates in the inner loop
device (optional) – Device on which the model is defined
References
[1] “On First-Order Meta-Learning Algorithms.” Alex Nichol, et al. arXiv 2018.
alg_name = 'Reptile'
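Reptile's meta-update needs no second-order terms at all: after the inner loop, the meta-parameters are simply moved a fraction of the way towards the task-adapted parameters. A conceptual sketch of that update (this illustrates the algorithm from [1], not metallic's internal code):

    import torch

    def reptile_outer_update(meta_params, fast_params, epsilon=0.1):
        # theta <- theta + epsilon * (phi - theta)
        with torch.no_grad():
            for theta, phi in zip(meta_params, fast_params):
                theta += epsilon * (phi - theta)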
class metallic.metalearners.MinibatchProx(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 5, lamb: float = 0.1, device: Optional = None)[source]
Bases: metallic.metalearners.gbml.reptile.Reptile
Implementation of the MinibatchProx algorithm proposed in [1].
The official TensorFlow implementation of MinibatchProx can be found here.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function, to be wrapped in metallic.functional.ProximalRegLoss()
inner_steps (int, optional, default=5) – Number of gradient descent updates in the inner loop
lamb (float, optional, default=0.1) – Coefficient of the proximal regularization term
device (optional) – Device on which the model is defined
References
[1] “Efficient Meta Learning via Minibatch Proximal Update.” Pan Zhou, et al. NeurIPS 2019.
alg_name = 'MinibatchProx'
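MinibatchProx regularizes the inner objective with a proximal term that ties the fast weights to the meta-initialization, which is presumably what metallic.functional.ProximalRegLoss() adds around the base loss (its exact signature is not shown on this page). Conceptually:

    import torch

    def proximal_reg_loss(task_loss, fast_params, meta_params, lamb=0.1):
        # task loss + (lamb / 2) * ||phi - theta||^2, summed over all weights
        reg = sum(((phi - theta) ** 2).sum()
                  for phi, theta in zip(fast_params, meta_params))
        return task_loss + 0.5 * lamb * reg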
class metallic.metalearners.ANIL(model: torch.nn.modules.module.Module, in_optim: torch.optim.optimizer.Optimizer, out_optim: torch.optim.optimizer.Optimizer, root: str, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, inner_steps: int = 1, device: Optional = None)[source]
Bases: metallic.metalearners.gbml.base.GBML
Implementation of the Almost No Inner Loop (ANIL) algorithm proposed in [1], which only updates the head of the neural network in the inner loop.
- Parameters
model (torch.nn.Module) – Model to be wrapped
in_optim (torch.optim.Optimizer) – Optimizer for the inner loop
out_optim (torch.optim.Optimizer) – Optimizer for the outer loop
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
inner_steps (int, optional, default=1) – Number of gradient descent updates in the inner loop
device (optional) – Device on which the model is defined
References
[1] “Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.” Aniruddh Raghu, et al. ICLR 2020.
alg_name = 'ANIL'
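One way to realize "only the head is updated in the inner loop" with the documented signature would be to restrict in_optim to the final layer's parameters; how ANIL actually selects the head internally is not documented on this page, so treat this as an assumption:

    import torch

    # Assumes the wrapped model exposes its classifier as the last child
    # module (true for the nn.Sequential sketch above).
    head = list(model.children())[-1]
    in_optim = torch.optim.SGD(head.parameters(), lr=0.01)      # head only
    out_optim = torch.optim.Adam(model.parameters(), lr=0.001)  # all parameters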
Metric-Based Meta-Learning
class metallic.metalearners.ProtoNet(model: torch.nn.modules.module.Module, optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, distance: str = 'euclidean', device: Optional = None)[source]
Bases: metallic.metalearners.mbml.base.MBML
Implementation of Prototypical Networks proposed in [1].
The official PyTorch implementation of Prototypical Networks can be found here.
- Parameters
model (torch.nn.Module) – Model to be wrapped
optim (torch.optim.Optimizer) – Optimizer
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
distance (str, optional, default='euclidean') – Type of distance function to be used for computing similarity
device (optional) – Device on which the model is defined. If None, device will be detected automatically.
References
[1] “Prototypical Networks for Few-shot Learning.” Jake Snell, et al. NIPS 2017.
alg_name = 'ProtoNet'
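The core computation of Prototypical Networks is compact: each class prototype is the mean embedding of that class's support points, and a query is scored by its negative squared Euclidean distance to each prototype. A standalone sketch of that computation, independent of metallic's internals:

    import torch

    def prototype_logits(support_emb, support_targets, query_emb, n_way):
        # support_emb: (n_support, d), support_targets: (n_support,),
        # query_emb: (n_query, d) -> logits of shape (n_query, n_way)
        prototypes = torch.stack([
            support_emb[support_targets == c].mean(dim=0)
            for c in range(n_way)
        ])
        return -torch.cdist(query_emb, prototypes) ** 2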
class metallic.metalearners.MatchNet(model: torch.nn.modules.module.Module, optim: torch.optim.optimizer.Optimizer, root: Optional[str] = None, save_basename: Optional[str] = None, lr_scheduler: Optional[Callable] = None, loss_function: Optional[Callable] = None, distance: str = 'cosine', device: Optional = None)[source]
Bases: metallic.metalearners.mbml.base.MBML
Implementation of Matching Networks proposed in [1].
- Parameters
model (torch.nn.Module) – Model to be wrapped
optim (torch.optim.Optimizer) – Optimizer
root (str) – Root directory to save checkpoints
save_basename (str, optional) – Base name of the saved checkpoints
lr_scheduler (callable, optional) – Learning rate scheduler
loss_function (callable, optional) – Loss function
distance (str, optional, default='cosine') – Type of distance function to be used for computing similarity
device (optional) – Device on which the model is defined. If None, device will be detected automatically.
References
[1] “Matching Networks for One Shot Learning.” Oriol Vinyals, et al. NIPS 2016.
alg_name = 'MatchNet'
attention(distance: torch.FloatTensor, targets: torch.LongTensor) → torch.FloatTensor[source]
An attention kernel which serves as a classifier. It defines a probability distribution over output labels given a query example.
The classifier output is defined as a weighted sum of the labels of the support points, where the weights are proportional to the similarity between support and query embeddings.
- Parameters
distance (torch.FloatTensor) – Similarity between support point embeddings and query point embeddings, with shape (n_samples_query, n_samples_support)
targets (torch.LongTensor) – Targets of the support points, with shape (n_samples_support)
- Returns
pred_pdf – Probability distribution over output labels, with shape (n_samples_query, n_way)
- Return type
torch.FloatTensor
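Given the documented shapes, the attention classifier reduces to a softmax over support similarities followed by a weighted sum of one-hot support labels. A minimal reimplementation consistent with (but not copied from) the source:

    import torch
    import torch.nn.functional as F

    def attention(distance, targets):
        # distance: (n_query, n_support), targets: (n_support,)
        weights = F.softmax(distance, dim=1)   # attention over support points
        one_hot = F.one_hot(targets).float()   # (n_support, n_way)
        return weights @ one_hot               # (n_query, n_way)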