Source code for metallic.functional.loss

from typing import Callable, Optional
from torch import Tensor, nn

class ProximalRegLoss(nn.Module):
    """
    Add an explicit L2 regularization term, based on meta-parameters and
    model-parameters, to a loss function. This is because we want
    model-parameters to retain a close dependence on meta-parameters. This
    loss function has been used in [1] and [2].

    The value returned by a call is::

        loss_function(input, target)
            + 0.5 * lamb * sum_i ||init_params[i] - params[i]||^2

    Args:
        loss_function (callable, optional): Loss function, called as
            ``loss_function(input, target)``. The default ``None`` is only a
            placeholder: a callable must be set before the module is called.

        lamb (float, optional, default=0.1): Regularization strength of the
            inner level proximal regularization

    .. admonition:: References

        1. "`Efficient Meta Learning via Minibatch Proximal Update.
           <https://panzhous.github.io/assets/pdf/2019-NIPS-metaleanring.pdf>`_"
           Pan Zhou, et al. NIPS 2019. The supplementary file can be found
           `here <https://panzhous.github.io/assets/pdf/2019-NIPS-metaleanring-supplementary.pdf>`_.
        2. "`Meta-Learning with Implicit Gradients.
           <https://arxiv.org/abs/1909.04630>`_"
           Aravind Rajeswaran, et al. NIPS 2019.
    """

    def __init__(
        self, loss_function: Optional[Callable] = None, lamb: float = 0.1
    ) -> None:
        super().__init__()
        self.loss_function = loss_function
        self.lamb = lamb

    def forward(
        self, input: Tensor, target: Tensor, init_params: Tensor, params: Tensor
    ) -> Tensor:
        """
        Compute the wrapped loss plus the proximal regularization term.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that ``nn.Module.__call__`` dispatches here and registered hooks run.
        Calling the module instance directly works exactly as before.

        Args:
            input (Tensor): First argument forwarded to ``loss_function``.
            target (Tensor): Second argument forwarded to ``loss_function``.
            init_params: Meta-parameters. Iterated in lockstep with
                ``params`` (e.g. a list of tensors).
            params: Model-parameters, paired element-wise with
                ``init_params``. Pairing stops at the shorter of the two.

        Returns:
            Tensor: ``loss_function(input, target)`` plus
            ``0.5 * lamb * sum(||init_param - param||^2)``.

        Raises:
            TypeError: If ``loss_function`` is ``None``.
        """
        if self.loss_function is None:
            # Same exception type as calling ``None`` would raise, but with a
            # message that names the actual cause.
            raise TypeError(
                "ProximalRegLoss requires a callable `loss_function`, got None"
            )

        # Squared L2 distance between each (meta-parameter, model-parameter)
        # pair, accumulated over all pairs.
        sq_diff = sum(
            ((init_param - param) ** 2).sum()
            for init_param, param in zip(init_params, params)
        )
        reg_term = 0.5 * self.lamb * sq_diff
        return self.loss_function(input, target) + reg_term