Source code for metallic.models.simple

from typing import Optional
import torch
from torch import nn
from .modules import ConvGroup, Flatten, LinearBlock

class OmniglotCNN(nn.Module):
    """
    The convolutional network used for experiments on Omniglot, first
    introduced in [1]. It has 4 modules with 3 × 3 convolutions and 64
    filters, each followed by batch normalization, a ReLU nonlinearity,
    and 2 × 2 max-pooling. This network assumes the images are downsampled
    to 28 × 28 and have 1 channel, i.e., the inputs have shape (1, 28, 28).

    Parameters
    ----------
    n_classes : int, optional
        Size of the network's output. This corresponds to ``N`` in
        ``N-way`` classification. ``None`` if the linear classifier
        is not needed.

    .. admonition:: References

        1. "`Matching Networks for One Shot Learning. <https://arxiv.org/abs/1606.04080>`_" \
           Oriol Vinyals, et al. NIPS 2016.
    """

    def __init__(self, n_classes: Optional[int] = None) -> None:
        super(OmniglotCNN, self).__init__()

        self.hidden_size = 64

        base = ConvGroup(
            in_channels = 1,
            hidden_size = self.hidden_size,
            layers = 4
        )
        self.encoder = nn.Sequential(
            base,      # (batch_size, 64, 1, 1): 28 → 14 → 7 → 3 → 1 after four 2 × 2 max-pools
            Flatten()  # (batch_size, 64)
        )

        self.n_classes = n_classes
        if n_classes:
            self.classifier = nn.Linear(self.hidden_size, n_classes)
            self.init_weights()
    def init_weights(self) -> None:
        # Draw classifier weights from a standard normal and zero the bias.
        self.classifier.weight.data.normal_()
        self.classifier.bias.data.mul_(0.0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input data (batch_size, in_channels=1, img_size=28, img_size=28)

        Returns
        -------
        output : torch.Tensor
            If ``n_classes`` is not None, return class scores
            ``(batch_size, n_classes)``; otherwise return embedded
            features ``(batch_size, 64)``.
        """
        output = self.encoder(x)  # (batch_size, 64)
        if self.n_classes:
            output = self.classifier(output)  # (batch_size, n_classes)
        return output
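A minimal usage sketch (not part of the module itself): it assumes ``OmniglotCNN`` is imported from this module, and the batch size of 8 and the 5-way setting are arbitrary illustrative choices.

    model = OmniglotCNN(n_classes=5)   # 5-way classification head
    x = torch.randn(8, 1, 28, 28)      # random batch of Omniglot-sized inputs
    scores = model(x)                  # -> (8, 5) class scores

    encoder = OmniglotCNN()            # n_classes=None: embeddings only
    features = encoder(x)              # -> (8, 64) embedded features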
class MiniImagenetCNN(nn.Module):
    """
    The convolutional network used for experiments on MiniImagenet, first
    introduced in [1]. It has 4 modules with 3 × 3 convolutions and 32
    filters, each followed by batch normalization, a ReLU nonlinearity,
    and 2 × 2 max-pooling. This network assumes the images are downsampled
    to 84 × 84 and have 3 channels, i.e., the inputs have shape (3, 84, 84).

    Parameters
    ----------
    n_classes : int, optional
        Size of the network's output. This corresponds to ``N`` in
        ``N-way`` classification. ``None`` if the linear classifier
        is not needed.

    .. admonition:: References

        1. "`Optimization as a Model for Few-Shot Learning. \
           <https://openreview.net/pdf?id=rJY0-Kcll>`_" Sachin Ravi, et al. ICLR 2017.
    """

    def __init__(self, n_classes: Optional[int] = None) -> None:
        super(MiniImagenetCNN, self).__init__()

        self.hidden_size = 32

        base = ConvGroup(
            in_channels = 3,
            hidden_size = self.hidden_size,
            layers = 4
        )
        self.encoder = nn.Sequential(
            base,      # (batch_size, 32, 5, 5): 84 → 42 → 21 → 10 → 5 after four 2 × 2 max-pools
            Flatten()  # (batch_size, 32 × 5 × 5 = 800)
        )

        self.n_classes = n_classes
        if n_classes:
            self.classifier = nn.Linear(5 * 5 * self.hidden_size, n_classes)
            self.init_weights()
    def init_weights(self) -> None:
        # Draw classifier weights from a standard normal and zero the bias.
        self.classifier.weight.data.normal_()
        self.classifier.bias.data.mul_(0.0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input data (batch_size, in_channels=3, img_size=84, img_size=84)

        Returns
        -------
        output : torch.Tensor
            If ``n_classes`` is not None, return class scores
            ``(batch_size, n_classes)``; otherwise return embedded
            features ``(batch_size, 800)``.
        """
        output = self.encoder(x)  # (batch_size, 800)
        if self.n_classes:
            output = self.classifier(output)  # (batch_size, n_classes)
        return output
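The same sketch for ``MiniImagenetCNN``; again, the batch size and 5-way setting are illustrative only.

    model = MiniImagenetCNN(n_classes=5)  # 5-way classification head
    x = torch.randn(8, 3, 84, 84)         # random batch of MiniImagenet-sized inputs
    scores = model(x)                     # -> (8, 5) class scores

    encoder = MiniImagenetCNN()           # n_classes=None: embeddings only
    features = encoder(x)                 # -> (8, 800) embedded features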
class OmniglotMLP(nn.Module):
    """
    The fully-connected network used for experiments on Omniglot, first
    introduced in [1]. It has 4 hidden layers with sizes 256, 128, 64, 64,
    each including batch normalization and ReLU nonlinearities, followed
    by a linear layer and softmax.

    Parameters
    ----------
    input_size : int
        Size of the network's input

    n_classes : int
        Size of the network's output. This corresponds to ``N`` in
        ``N-way`` classification.

    .. admonition:: References

        1. "`Meta-Learning with Memory-Augmented Neural Networks. \
           <http://proceedings.mlr.press/v48/santoro16.pdf>`_" Adam Santoro, et al. ICML 2016.
    """

    def __init__(self, input_size: int, n_classes: int) -> None:
        super(OmniglotMLP, self).__init__()

        linear_sizes = [input_size, 256, 128, 64, 64]
        layers = [
            LinearBlock(in_size, out_size)
            for in_size, out_size in zip(linear_sizes[:-1], linear_sizes[1:])
        ]
        base = nn.Sequential(*layers)
        self.encoder = nn.Sequential(Flatten(), base)

        self.classifier = nn.Linear(linear_sizes[-1], n_classes)
        self.init_weights()
    def init_weights(self) -> None:
        # Draw classifier weights from a standard normal and zero the bias.
        self.classifier.weight.data.normal_()
        self.classifier.bias.data.mul_(0.0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)          # (batch_size, 64)
        output = self.classifier(features)  # (batch_size, n_classes)
        return output
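A minimal sketch for ``OmniglotMLP``. Since the encoder begins with ``Flatten()``, a 28 × 28 single-channel image corresponds to ``input_size = 28 * 28 = 784``; that value and the 5-way setting are assumptions for illustration.

    model = OmniglotMLP(input_size=28 * 28, n_classes=5)
    x = torch.randn(8, 1, 28, 28)      # Flatten() reshapes this to (8, 784)
    scores = model(x)                  # -> (8, 5) class scores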