metallic.functional

Loss Functions

class metallic.functional.ProximalRegLoss(loss_function: Optional[Callable] = None, lamb: float = 0.1)

Bases: torch.nn.modules.module.Module

Add an explicit L2 regularization term, based on the distance between the meta-parameters and the model parameters, to the loss function. The term keeps the model parameters close to the meta-parameters, so that they retain a strong dependence on them.

This loss function has been used in [1] and [2].

Parameters
  • loss_function (callable, optional) – Base loss function to which the regularization term is added

  • lamb (float, optional, default=0.1) – Regularization strength of the inner-level proximal regularization
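The following is a hypothetical stand-alone sketch of how a task loss and the proximal term could be combined; it is not the `ProximalRegLoss` implementation. It assumes the objective form used in [2], loss_function(preds, targets) + (lamb / 2) * sum ||theta - phi||^2, where theta are the model parameters and phi the meta-parameters. The 1/2 factor, the detaching of the meta-parameters, and the name `proximal_reg_loss` are assumptions; the library may scale or differentiate the term differently.

```python
import copy
from typing import Callable, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F


def proximal_reg_loss(
    loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    preds: torch.Tensor,
    targets: torch.Tensor,
    model_params: Iterable[torch.Tensor],
    meta_params: Iterable[torch.Tensor],
    lamb: float = 0.1,
) -> torch.Tensor:
    """Hypothetical helper: task loss plus an L2 proximal term.

    Assumes the (lamb / 2) * ||theta - phi||^2 scaling from iMAML;
    metallic's ProximalRegLoss may scale the term differently.
    """
    task_loss = loss_function(preds, targets)
    # Sum of squared differences between each model parameter and the
    # matching meta-parameter. Meta-parameters are detached (an assumption)
    # so the inner-loop gradient only flows into the model parameters.
    prox = sum(
        torch.sum((p - m.detach()) ** 2)
        for p, m in zip(model_params, meta_params)
    )
    return task_loss + 0.5 * lamb * prox


# Usage: the task-specific model starts as a copy of the meta-model.
torch.manual_seed(0)
meta_model = nn.Linear(4, 3)
model = copy.deepcopy(meta_model)
x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))

loss = proximal_reg_loss(
    F.cross_entropy, model(x), y,
    model.parameters(), meta_model.parameters(), lamb=0.1,
)
loss.backward()  # gradients land in model's parameters only
```

Right after the copy the proximal term is zero, so the loss equals the plain task loss; it grows as inner-loop updates move the model parameters away from the meta-parameters.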

References

  1. “Efficient Meta Learning via Minibatch Proximal Update.” Pan Zhou, et al. NIPS 2019.

  2. “Meta-Learning with Implicit Gradients.” Aravind Rajeswaran, et al. NIPS 2019.


Distance Functions

Distance functions for measuring the similarity between two tensors. They are useful in metric-based meta-learning algorithms.
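As an illustration of the kind of function this section covers, here is a hypothetical pairwise squared Euclidean distance, a common choice in metric-based methods such as Prototypical Networks for comparing query embeddings with class prototypes. The name `euclidean_dist` and its input shapes are assumptions, not the documented metallic API.

```python
import torch


def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise squared Euclidean distance.

    x: (n, d), y: (m, d) -> (n, m), where out[i, j] = ||x[i] - y[j]||^2.
    """
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d); sum over the feature dim.
    return ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(dim=-1)


queries = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
prototypes = torch.tensor([[1.0, 0.0], [3.0, 4.0]])
dists = euclidean_dist(queries, prototypes)  # tensor([[1., 25.], [1., 13.]])
```

The broadcasting version materializes an (n, m, d) tensor; for large inputs, `torch.cdist(x, y) ** 2` computes the same distances (up to floating-point error) without that intermediate.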

Gradients

Operations on gradients, which are useful in gradient-based meta-learning algorithms.

metallic.functional.apply_grads(model: torch.nn.modules.module.Module, grads: Sequence[torch.Tensor]) → None

Map a list of gradients to the parameters of a model.

Parameters
  • model (torch.nn.Module) – Model whose parameters receive the gradients

  • grads (Sequence[torch.Tensor]) – List of gradients, one for each model parameter
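A typical use of such a function is the outer loop of a first-order method: gradients are computed with `torch.autograd.grad` (for example on a task-adapted copy of the model) and written back to a model before an optimizer step. The sketch below is hypothetical: `apply_grads_sketch` assumes the gradients are ordered like `model.parameters()` and that each one overwrites the existing `.grad`; metallic's `apply_grads` may behave differently (for example by accumulating).

```python
from typing import Sequence

import torch
import torch.nn as nn


def apply_grads_sketch(model: nn.Module, grads: Sequence[torch.Tensor]) -> None:
    """Hypothetical helper: write one gradient into each parameter's .grad.

    Assumes grads follow the order of model.parameters() and overwrite
    any existing gradients.
    """
    params = list(model.parameters())
    if len(params) != len(grads):
        raise ValueError(f"expected {len(params)} gradients, got {len(grads)}")
    for p, g in zip(params, grads):
        p.grad = g.detach().clone()


# Usage: compute gradients functionally, then let an optimizer apply them.
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(4, 2)).sum()
grads = torch.autograd.grad(loss, list(model.parameters()))
w_before = model.weight.detach().clone()
apply_grads_sketch(model, grads)
optimizer.step()  # plain SGD: weight <- weight - lr * grad
```

Separating gradient computation from gradient application lets the same optimizer update the meta-model regardless of where the gradients came from.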

Prototypes

metallic.functional.get_prototypes(inputs: torch.FloatTensor, targets: torch.LongTensor) → torch.FloatTensor

Compute the prototypes for each class in the task. Each prototype is the mean vector of the embedded support points belonging to its class.

Parameters
  • inputs (torch.FloatTensor) – Embeddings of the support points, with shape (n_samples, embed_dim)

  • targets (torch.LongTensor) – Targets of the support points, with shape (n_samples,)

Returns

Prototypes for each class, with shape (n_way, embed_dim).

Return type

prototypes (torch.FloatTensor)
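The per-class mean described above can be sketched with `torch.bincount` and `Tensor.index_add_`. `get_prototypes_sketch` is a hypothetical name; it assumes the targets are class indices 0 to n_way - 1 and that every class has at least one support point (otherwise the division produces NaN). The usage lines then classify queries by negative squared Euclidean distance to the prototypes, as in Prototypical Networks.

```python
import torch


def get_prototypes_sketch(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean embedding per class.

    inputs: (n_samples, embed_dim), targets: (n_samples,) -> (n_way, embed_dim).
    Assumes targets are 0..n_way-1 and every class occurs at least once.
    """
    n_way = int(targets.max().item()) + 1
    # Number of support points per class, shape (n_way, 1) for broadcasting.
    counts = torch.bincount(targets, minlength=n_way).to(inputs.dtype).unsqueeze(1)
    # Add each embedding into row `target` of a zero tensor: per-class sums.
    sums = inputs.new_zeros(n_way, inputs.size(-1)).index_add_(0, targets, inputs)
    return sums / counts


support = torch.tensor([[0.0, 0.0], [2.0, 2.0], [1.0, 0.0], [3.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1])
protos = get_prototypes_sketch(support, labels)  # tensor([[1., 1.], [2., 0.]])

# Classify queries by the nearest prototype (negative squared distance as logits).
queries = torch.tensor([[1.0, 1.5], [2.5, 0.0]])
logits = -torch.cdist(queries, protos) ** 2
preds = logits.argmax(dim=1)  # tensor([0, 1])
```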