Source code for metallic.data.datasets.miniimagenet

import os
from typing import Callable, Optional, List
from collections import defaultdict
import pickle
from PIL import Image
import torch
from torchvision.datasets.utils import download_file_from_google_drive, \
    extract_archive, check_integrity

from .base import ClassDataset, MetaDataset

class MiniImageNetClassDataset(ClassDataset):
    """
    A dataset composed of classes from mini-ImageNet.

    Parameters
    ----------
    root : str
        Root directory of the dataset

    meta_split : str, optional, default='train'
        Name of the split to be used: 'train' / 'val' / 'test'

    transform : Callable, optional
        A function/transform that takes in a PIL image and returns a
        transformed version

    target_transform : Callable, optional
        A function/transform that takes in the target and transforms it

    augmentations : List[Callable], optional
        A list of functions that augment the dataset with new classes

    download : bool, optional, default=False
        If true, downloads the dataset zip files from the internet and
        puts them in the root directory. If the zip files are already
        downloaded, they are not downloaded again.
    """

    dataset_name = 'mini-imagenet'
    google_drive_id = '16V_ZlkW4SsnNDtnGmaBRq2OoPmUOc5mY'
    zip_md5 = 'b38f1eb4251fb9459ecc8e7febf9b2eb'
    pkl_name = 'mini-imagenet-cache-{0}.pkl'

    def __init__(
        self,
        root: str,
        meta_split: str = 'train',
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        augmentations: Optional[List[Callable]] = None,
        download: bool = False
    ) -> None:
        super(MiniImageNetClassDataset, self).__init__(
            root = os.path.join(root, self.dataset_name),
            meta_split = meta_split,
            cache_path = self.dataset_name + '_cache.pth.tar',
            transform = transform,
            target_transform = target_transform,
            augmentations = augmentations
        )

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it.')

        self.preprocess()

    def _check_integrity(self) -> bool:
        # Verify the MD5 checksum of the downloaded archive.
        return check_integrity(
            os.path.join(self.root, self.dataset_name + '.tar.gz'),
            self.zip_md5
        )

    def download(self) -> None:
        """Download the dataset archive from Google Drive and extract it."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self.dataset_name + '.tar.gz'
        download_file_from_google_drive(
            file_id = self.google_drive_id,
            root = self.root,
            filename = filename,
            md5 = self.zip_md5
        )

        archive = os.path.join(self.root, filename)
        print("Extracting {} to {}".format(archive, self.root))
        extract_archive(archive, self.root)

    def create_cache(self) -> None:
        """Read the three split pickles and index images by class label."""
        self.labels = {}
        self.label_to_images = defaultdict(list)

        cumulative_size = 0

        for split in ["train", "val", "test"]:
            pkl_path = os.path.join(self.root, self.pkl_name.format(split))

            with open(pkl_path, 'rb') as f:
                data = pickle.load(f)
                images, targets = data['image_data'], data['class_dict']
                n_classes = len(targets)

                # Map each class name to a unique shuffled integer label,
                # offset by ``cumulative_size`` so labels never collide
                # across splits.
                categorical = (torch.randperm(n_classes) + cumulative_size).tolist()
                to_categorical = dict(
                    (target, categorical[i])
                    for (i, target) in enumerate(targets.keys())
                )

                self.labels[split] = categorical

                # Decode each class's rows of the raw image array into
                # PIL images.
                for label, indices in targets.items():
                    self.label_to_images[to_categorical[label]] = [
                        Image.fromarray(image) for image in images[indices]
                    ]

            cumulative_size += n_classes
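
# For reference, each cache file ``mini-imagenet-cache-{split}.pkl`` read by
# ``create_cache`` above unpickles to a dict with the two keys the code
# accesses:
#
#     {
#         'image_data': <array of raw images, indexable by row>,
#         'class_dict': {class_name: [row indices into image_data], ...}
#     }
#
# The images are conventionally 84x84x3 uint8 arrays (an assumption here, not
# something this module checks); the code itself only requires that
# ``image_data[indices]`` yield arrays accepted by ``Image.fromarray``.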

class MiniImageNet(MetaDataset):
    """
    The mini-ImageNet dataset introduced in [1]. It samples 100 classes
    from ImageNet (ILSVRC-2012), of which 64 are for training, 16 for
    validation, and 20 for testing. Each class contains 600 samples.

    The dataset is downloaded from `here
    <https://github.com/renmengye/few-shot-ssl-public/>`_.

    NOTE: [1] didn't release their splits at first, so [2] created their
    own splits. Here we use the splits from [2].

    Parameters
    ----------
    root : str
        Root directory of the dataset

    n_way : int
        Number of classes per task

    meta_split : str, optional, default='train'
        Name of the split to be used: 'train' / 'val' / 'test'

    k_shot_support : int, optional
        Number of samples per class in the support set

    k_shot_query : int, optional
        Number of samples per class in the query set

    shuffle : bool, optional, default=True
        If ``True``, samples in a class will be shuffled before being
        split into support and query sets

    transform : Callable, optional
        A function/transform that takes in a PIL image and returns a
        transformed version

    target_transform : Callable, optional
        A function/transform that takes in the target and transforms it

    augmentations : List[Callable], optional
        A list of functions that augment the dataset with new classes

    download : bool, optional, default=False
        If true, downloads the dataset zip files from the internet and
        puts them in the root directory. If the zip files are already
        downloaded, they are not downloaded again.

    .. admonition:: References

        1. "`Matching Networks for One Shot Learning. \
           <https://arxiv.org/abs/1606.04080>`_" Oriol Vinyals, et al. NIPS 2016.
        2. "`Optimization as a Model for Few-Shot Learning. \
           <https://openreview.net/pdf?id=rJY0-Kcll>`_" Sachin Ravi, et al. ICLR 2017.
    """

    def __init__(
        self,
        root: str,
        n_way: int,
        meta_split: str = 'train',
        k_shot_support: Optional[int] = None,
        k_shot_query: Optional[int] = None,
        shuffle: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        augmentations: Optional[List[Callable]] = None,
        download: bool = False
    ) -> None:
        dataset = MiniImageNetClassDataset(
            root = root,
            meta_split = meta_split,
            transform = transform,
            target_transform = target_transform,
            augmentations = augmentations,
            download = download
        )

        super(MiniImageNet, self).__init__(
            dataset = dataset,
            n_way = n_way,
            k_shot_support = k_shot_support,
            k_shot_query = k_shot_query,
            shuffle = shuffle
        )
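
# --------------------------------------------------------------------------
# Usage sketch (not part of the library): how one might construct a 5-way
# 1-shot mini-ImageNet dataset from this module. The constructor arguments
# are taken from the signatures above; everything after construction (how
# tasks are actually sampled) is defined by ``MetaDataset`` in ``.base`` and
# is only hinted at here, not a guaranteed API.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    from torchvision import transforms

    dataset = MiniImageNet(
        root = 'data',            # data will live under data/mini-imagenet
        n_way = 5,                # 5 classes per task
        meta_split = 'train',     # use the 64 training classes
        k_shot_support = 1,       # 1 sample per class in the support set
        k_shot_query = 15,        # 15 samples per class in the query set
        transform = transforms.ToTensor(),
        download = True           # fetch and extract the archive if needed
    )

    # Task sampling itself is implemented by ``MetaDataset`` (see ``.base``);
    # the call below is illustrative only.
    # task = dataset.sample_task()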