from typing import Optional
import torch
from torch import nn
from .modules import ConvGroup, Flatten, LinearBlock
class OmniglotCNN(nn.Module):
    """
    The convolutional network used for experiments on Omniglot, first
    introduced by [1].

    It has 4 modules, each with a 3 × 3 convolution with 64 filters, followed
    by batch normalization, a ReLU nonlinearity, and 2 × 2 max-pooling.
    This network assumes the images are downsampled to 28 × 28 and have 1
    channel, i.e. the shape of each input is (1, 28, 28).

    Parameters
    ----------
    n_classes : int, optional
        Size of the network's output. This corresponds to ``N`` in ``N-way``
        classification. ``None`` if the linear classifier is not needed, in
        which case the network returns the 64-dimensional embedding.

    .. admonition:: References

        1. "`Matching Networks for One Shot Learning. <https://arxiv.org/abs/1606.04080>`_" \
            Oriol Vinyals, et al. NIPS 2016.
    """

    def __init__(self, n_classes: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = 64
        base = ConvGroup(
            in_channels=1,
            hidden_size=self.hidden_size,
            layers=4,
        )
        self.encoder = nn.Sequential(
            base,  # (batch_size, 64, 1, 1): spatial size 28 / 16 -> 1
            Flatten(),  # (batch_size, 64)
        )
        self.n_classes = n_classes
        # A falsy ``n_classes`` (None, or 0) means "embedding only": no
        # classifier head is created and ``forward`` returns the features.
        if n_classes:
            self.classifier = nn.Linear(self.hidden_size, n_classes)
            self.init_weights()

    def init_weights(self) -> None:
        """
        Initialize the linear classifier: weights drawn from N(0, 1), biases
        set to zero.

        Requires ``self.classifier`` to exist, i.e. the model was built with
        a truthy ``n_classes``.
        """
        # ``nn.init`` works under ``torch.no_grad()``, so there is no need to
        # reach into ``.data`` (which bypasses autograd's version tracking).
        nn.init.normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input data (batch_size, in_channels=1, img_size=28, img_size=28)

        Returns
        -------
        output : torch.Tensor
            If ``n_classes`` is set, class scores ``(batch_size, n_classes)``;
            otherwise the embedded features ``(batch_size, 64)``.
        """
        output = self.encoder(x)  # (batch_size, 64)
        if self.n_classes:
            output = self.classifier(output)  # (batch_size, n_classes)
        return output
class MiniImagenetCNN(nn.Module):
    """
    The convolutional network used for experiments on MiniImagenet, first
    introduced by [1].

    It has 4 modules, each with a 3 × 3 convolution with 32 filters, followed
    by batch normalization, a ReLU nonlinearity, and 2 × 2 max-pooling.
    This network assumes the images are downsampled to 84 × 84 and have 3
    channels, i.e. the shape of each input is (3, 84, 84).

    Parameters
    ----------
    n_classes : int, optional
        Size of the network's output. This corresponds to ``N`` in ``N-way``
        classification. ``None`` if the linear classifier is not needed, in
        which case the network returns the 800-dimensional embedding.

    .. admonition:: References

        1. "`Optimization as a Model for Few-Shot Learning. \
            <https://openreview.net/pdf?id=rJY0-Kcll>`_" Sachin Ravi, et al. ICLR 2017.
    """

    def __init__(self, n_classes: Optional[int] = None) -> None:
        # Zero-argument super(): the previous ``super(OmniglotCNN, self)``
        # named the wrong class and raised TypeError on construction.
        super().__init__()
        self.hidden_size = 32
        base = ConvGroup(
            in_channels=3,
            hidden_size=self.hidden_size,
            layers=4,
        )
        self.encoder = nn.Sequential(
            base,  # (batch_size, 32, 5, 5): spatial size 84 / 16 -> 5
            Flatten(),  # (batch_size, 32 × 5 × 5 = 800)
        )
        self.n_classes = n_classes
        # A falsy ``n_classes`` (None, or 0) means "embedding only": no
        # classifier head is created and ``forward`` returns the features.
        if n_classes:
            self.classifier = nn.Linear(5 * 5 * self.hidden_size, n_classes)
            self.init_weights()

    def init_weights(self) -> None:
        """
        Initialize the linear classifier: weights drawn from N(0, 1), biases
        set to zero.

        Requires ``self.classifier`` to exist, i.e. the model was built with
        a truthy ``n_classes``.
        """
        nn.init.normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input data (batch_size, in_channels=3, img_size=84, img_size=84)

        Returns
        -------
        output : torch.Tensor
            If ``n_classes`` is set, class scores ``(batch_size, n_classes)``;
            otherwise the embedded features ``(batch_size, 800)``.
        """
        output = self.encoder(x)  # (batch_size, 800)
        if self.n_classes:
            output = self.classifier(output)  # (batch_size, n_classes)
        return output
class OmniglotMLP(nn.Module):
    """
    The fully-connected network used for experiments on Omniglot, first
    introduced by [1].

    It has 4 hidden layers with sizes 256, 128, 64, 64 (each a
    ``LinearBlock``; per [1], these include batch normalization and ReLU
    nonlinearities), followed by a final linear layer. No softmax is applied:
    ``forward`` returns unnormalized class scores (logits), suitable for
    e.g. ``nn.CrossEntropyLoss``.

    Parameters
    ----------
    input_size : int
        Size of the network's input, i.e. the number of features after the
        input is flattened.
    n_classes : int
        Size of the network's output. This corresponds to ``N`` in ``N-way``
        classification.

    .. admonition:: References

        1. "`Meta-Learning with Memory-Augmented Neural Networks. \
            <http://proceedings.mlr.press/v48/santoro16.pdf>`_" Adam Santoro, et al. ICML 2016.
    """

    def __init__(self, input_size: int, n_classes: int) -> None:
        super().__init__()
        linear_sizes = [input_size, 256, 128, 64, 64]
        # One block per consecutive (in, out) pair of sizes.
        layers = [
            LinearBlock(in_size, out_size)
            for in_size, out_size in zip(linear_sizes[:-1], linear_sizes[1:])
        ]
        base = nn.Sequential(*layers)
        self.encoder = nn.Sequential(Flatten(), base)
        self.classifier = nn.Linear(linear_sizes[-1], n_classes)
        self.init_weights()

    def init_weights(self) -> None:
        """
        Initialize the linear classifier: weights drawn from N(0, 1), biases
        set to zero.
        """
        nn.init.normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input data; the encoder flattens it first, so its non-batch
            dimensions are presumably expected to multiply to ``input_size``.

        Returns
        -------
        output : torch.Tensor
            Unnormalized class scores ``(batch_size, n_classes)``.
        """
        features = self.encoder(x)
        output = self.classifier(features)
        return output