"""Omniglot dataset for few-shot learning (``metallic.data.datasets.omniglot``)."""

import os
from collections import defaultdict
from typing import Callable, Optional, Dict, List
from PIL import ImageOps
import torch
from torch.utils.data import ConcatDataset
from torchvision.datasets.omniglot import Omniglot as TorchOmniglot

from .base import ClassDataset, MetaDataset
from .. import _utils

class OmniglotClassDataset(ClassDataset):
    """
    A dataset composed of classes from Omniglot.

    Args:
        root (str): Root directory of dataset
        meta_split (str, optional, default='train'): Name of the split to
            be used: 'train' / 'val' / 'test'
        use_vinyals_split (bool, optional, default=True): If ``True``, use
            the splits defined in [2], or use ``images_background`` for
            train split and ``images_evaluation`` for test split.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it
        augmentations (list of callable, optional): A list of functions
            that augment the dataset with new classes.
        download (bool, optional, default=False): If true, downloads the
            dataset zip files from the internet and puts it in root
            directory. If the zip files are already downloaded, they are
            not downloaded again.

    Raises:
        ValueError: If ``meta_split == 'val'`` while
            ``use_vinyals_split`` is ``False`` (only Vinyals' splits define
            a meta-validation set).
    """

    dataset_name = 'omniglot'

    def __init__(
        self,
        root: str,
        meta_split: str = 'train',
        use_vinyals_split: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        augmentations: Optional[List[Callable]] = None,
        download: bool = False
    ) -> None:
        super(OmniglotClassDataset, self).__init__(
            root = root,
            meta_split = meta_split,
            cache_path = self.dataset_name + '_cache.pth.tar',
            transform = transform,
            target_transform = target_transform,
            augmentations = augmentations
        )

        if self.meta_split == 'val' and (not use_vinyals_split):
            raise ValueError(
                'You must set `use_vinyals_split = True` to use the '
                'meta-validation split.'
            )

        # Vinyals' split labels are stored under 'vinyals_{split}' keys of
        # ``self.labels`` by ``create_cache``.
        if use_vinyals_split:
            self.meta_split = 'vinyals_{}'.format(meta_split)
        self.use_vinyals_split = use_vinyals_split

        self.omniglot = {}
        # background set
        self.omniglot['background'] = TorchOmniglot(
            root = self.root,
            background = True,
            download = download
        )
        # evaluation set; its labels are shifted to start right after the
        # background set's labels so the two sets never share a label
        self.omniglot['evaluation'] = TorchOmniglot(
            root = self.root,
            background = False,
            download = download,
            target_transform = lambda x: x + len(self.omniglot['background']._characters)
        )
        # combine them: background labels come first, then evaluation labels
        self.dataset = ConcatDataset((self.omniglot['background'], self.omniglot['evaluation']))

        self.preprocess()

    def create_cache(self) -> None:
        """
        Build ``self.label_to_images`` and ``self.labels``.

        ``self.label_to_images`` maps every label to the list of its
        color-inverted images. ``self.labels`` maps each split name
        ('train', 'test', 'vinyals_train', 'vinyals_val', 'vinyals_test') to
        the labels belonging to that split.
        """
        self.labels = {}
        self.label_to_images = defaultdict(list)

        # create a map of target to samples; this is the only pass that
        # actually needs to load the images
        for (image, label) in self.dataset:
            self.label_to_images[label].append(ImageOps.invert(image))

        # background / evaluation split: background labels are
        # 0..n_background-1 and evaluation labels follow (see the
        # ``target_transform`` offset in ``__init__``). Derived from the
        # character lists instead of re-reading every image from disk;
        # assumes every character has at least one image (true for Omniglot).
        n_background = len(self.omniglot['background']._characters)
        n_evaluation = len(self.omniglot['evaluation']._characters)
        self.labels['train'] = list(range(n_background))
        self.labels['test'] = list(range(n_background, n_background + n_evaluation))

        # Vinyals' split: each split file maps set name -> alphabet ->
        # list of characters
        file_to_label = self._file_to_label(self.omniglot)
        for split in ['train', 'val', 'test']:
            split_name = 'vinyals_{}'.format(split)
            split_spec = _utils.load_splits(self.dataset_name, '{0}.json'.format(split))
            self.labels[split_name] = sorted(
                file_to_label['/'.join([set_name, alphabet, character])]
                for (set_name, alphabets) in split_spec.items()
                for (alphabet, characters) in alphabets.items()
                for character in characters
            )

    @staticmethod
    def _file_to_label(data: dict) -> Dict[str, int]:
        """
        Map each character path to its label.

        Args:
            data (dict): ``{'background': TorchOmniglot, 'evaluation':
                TorchOmniglot}`` as built in ``__init__``.

        Returns:
            dict: ``'{set_name}/{entry of _characters}'`` -> label, where
            evaluation labels are offset by the number of background
            characters (matching the ``target_transform`` in ``__init__``).
        """
        file_to_label = {}
        start = {
            'background': 0,
            'evaluation': len(data['background']._characters)
        }
        # a sample's label is ``start + index of its character``, so the
        # mapping follows directly from the character lists without loading
        # any image
        for name in ['background', 'evaluation']:
            for (index, character) in enumerate(data[name]._characters):
                file_to_label['/'.join([name, character])] = start[name] + index
        return file_to_label
class Omniglot(MetaDataset):
    """
    The Omniglot dataset introduced in [1]. It contains 1623 character
    classes from 50 different alphabets, each class with 20 samples.

    The original dataset is split into background (train) and evaluation
    (test) sets. We also provide a choice to use the splits from [2].

    The dataset is downloaded from `here <https://github.com/brendenlake/omniglot>`_,
    and the splits are taken from `here <https://github.com/tristandeleu/pytorch-meta/tree/master/torchmeta/datasets/assets/omniglot>`_.

    Parameters
    ----------
    root : str
        Root directory of dataset

    n_way : int
        Number of classes per task

    meta_split : str, optional, default='train'
        Name of the split to be used: 'train' / 'val' / 'test'

    use_vinyals_split : bool, optional, default=True
        If ``True``, use the splits defined in [2], or use
        ``images_background`` for train split and ``images_evaluation``
        for test split.

    k_shot_support : int, optional
        Number of samples per class in support set

    k_shot_query : int, optional
        Number of samples per class in query set

    shuffle : bool, optional, default=True
        If ``True``, samples in a class will be shuffled before being split
        into support and query sets

    transform : Callable, optional
        A function/transform that takes in a PIL image and returns a
        transformed version

    target_transform : Callable, optional
        A function/transform that takes in the target and transforms it

    augmentations : List[Callable], optional
        A list of functions that augment the dataset with new classes

    download : bool, optional, default=False
        If true, downloads the dataset zip files from the internet and puts
        them in the root directory. If the zip files are already downloaded,
        they are not downloaded again.

    NOTE: ``val`` split is not available when ``use_vinyals_split`` is set
    to ``False``.

    .. admonition:: References

        1. "`Human-level Concept Learning through Probabilistic Program Induction. \
           <http://www.sciencemag.org/content/350/6266/1332.short>`_" \
           *Brenden M. Lake, et al.* Science 2015.
        2. "`Matching Networks for One Shot Learning. \
           <https://arxiv.org/abs/1606.04080>`_" Oriol Vinyals, et al. NIPS 2016.
    """

    def __init__(
        self,
        root: str,
        n_way: int,
        meta_split: str = 'train',
        use_vinyals_split: bool = True,
        k_shot_support: Optional[int] = None,
        k_shot_query: Optional[int] = None,
        shuffle: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        augmentations: Optional[List[Callable]] = None,
        download: bool = False
    ) -> None:
        # class-level dataset that owns the images and the split label lists
        dataset = OmniglotClassDataset(
            root = root,
            meta_split = meta_split,
            use_vinyals_split = use_vinyals_split,
            transform = transform,
            target_transform = target_transform,
            augmentations = augmentations,
            download = download
        )

        # the MetaDataset base handles task sampling (n-way, k-shot)
        super(Omniglot, self).__init__(
            dataset = dataset,
            n_way = n_way,
            k_shot_support = k_shot_support,
            k_shot_query = k_shot_query,
            shuffle = shuffle
        )