metallic.functional

Loss Functions
class metallic.functional.ProximalRegLoss(loss_function: Optional[Callable] = None, lamb: float = 0.1)[source]

Bases: torch.nn.modules.module.Module
Adds an explicit L2 regularization term, measuring the distance between the model parameters and the meta-parameters, to the loss function. This encourages the model parameters to retain a close dependence on the meta-parameters.
This loss function has been used in [1] and [2].
Parameters

loss_function (callable, optional) – Loss function to which the regularization term is added
lamb (float, optional, default=0.1) – Strength of the inner-level proximal regularization term
References

[1] “Efficient Meta Learning via Minibatch Proximal Update.” Pan Zhou, et al. NeurIPS 2019.
[2] “Meta-Learning with Implicit Gradients.” Aravind Rajeswaran, et al. NeurIPS 2019.
training: bool
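The regularized loss described above can be sketched in a few lines. This is a minimal illustration of the idea, not metallic's implementation; the function name and argument layout are assumptions:

```python
import torch
import torch.nn.functional as F

def proximal_reg_loss(logits, targets, params, meta_params, lamb=0.1):
    """Cross-entropy loss plus an L2 proximal term that keeps the
    adapted (inner-loop) parameters close to the meta-parameters.
    Illustrative sketch only; names are hypothetical."""
    base = F.cross_entropy(logits, targets)
    # Sum of squared distances between each parameter and its meta-parameter
    reg = sum(((p - mp) ** 2).sum() for p, mp in zip(params, meta_params))
    return base + 0.5 * lamb * reg
```

With `lamb=0`, this reduces to the plain base loss; larger `lamb` pulls the inner-loop solution toward the meta-parameters.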
Distance Functions

Functions for computing distances between two tensors, useful for measuring similarity in metric-based meta-learning algorithms.
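As a generic example of such a distance function (a sketch, not necessarily one of the functions this module provides), a pairwise squared Euclidean distance between two batches of embeddings can be written with broadcasting:

```python
import torch

def pairwise_euclidean(x, y):
    """Squared Euclidean distance between every row of x (n, d)
    and every row of y (m, d); returns an (n, m) distance matrix.
    Generic sketch of a metric-based meta-learning distance function."""
    return ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(dim=-1)
```

In metric-based algorithms, `x` would typically hold query embeddings and `y` the class prototypes, and the negative distances serve as classification logits.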
Gradients

Some operations on gradients, which are useful in gradient-based meta-learning algorithms.
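A typical gradient operation in this setting is an inner-loop update that optionally keeps the computation graph so second-order gradients can flow through it, as in MAML-style algorithms. The helper below is hypothetical, shown only to illustrate the pattern:

```python
import torch

def inner_update(loss, params, lr=0.01, first_order=False):
    """One gradient-descent step on `params`. When first_order is False,
    create_graph=True keeps the graph so the outer (meta) loss can
    differentiate through this update. Hypothetical helper, not
    metallic's actual API."""
    grads = torch.autograd.grad(loss, params, create_graph=not first_order)
    return [p - lr * g for p, g in zip(params, grads)]
```

Setting `first_order=True` drops the second-order terms, trading some accuracy for a much cheaper backward pass.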
Prototypes

metallic.functional.get_prototypes(inputs: torch.FloatTensor, targets: torch.LongTensor) → torch.FloatTensor[source]

Compute the prototypes for each class in the task. Each prototype is the mean vector of the embedded support points belonging to its class.
Parameters

inputs (torch.FloatTensor) – Embeddings of the support points, with shape (n_samples, embed_dim)
targets (torch.LongTensor) – Targets of the support points, with shape (n_samples,)

Returns

prototypes – Prototypes for each class, with shape (n_way, embed_dim)

Return type

torch.FloatTensor
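The documented behaviour, class-wise mean of the support embeddings, can be sketched as follows. This assumes labels are integers 0..n_way-1 and is an illustration of the semantics above, not the library's implementation:

```python
import torch

def get_prototypes(inputs, targets):
    """Mean embedding per class: inputs has shape (n_samples, embed_dim),
    targets has shape (n_samples,) with labels in 0..n_way-1.
    Sketch of the behaviour documented above, not the library's code."""
    n_way = int(targets.max()) + 1
    one_hot = torch.nn.functional.one_hot(targets, n_way).float()  # (n, n_way)
    counts = one_hot.sum(dim=0).clamp(min=1).unsqueeze(1)          # (n_way, 1)
    # Class-wise sum of embeddings divided by class counts -> class means
    return one_hot.t() @ inputs / counts                           # (n_way, embed_dim)
```

The `clamp(min=1)` only guards against division by zero for a class that happens to have no support points.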