Source code for metallic.data.sampler
import random
from itertools import combinations
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .datasets import MetaDataset
class MetaSequentialSampler(SequentialSampler):
    """Sampler that enumerates class-index combinations in a fixed order.

    Each item produced by iterating this sampler is a tuple of ``n_way``
    distinct class indices drawn from ``range(n_classes)``, where both
    values are read from the wrapped data source.

    Args:
        data_source: Dataset exposing ``n_classes`` and ``n_way``
            attributes.
    """

    def __init__(self, data_source: MetaDataset):
        super().__init__(data_source)

    def __iter__(self):
        """Return an iterator over every ``n_way``-sized class combination.

        The combinations come straight from ``itertools.combinations``,
        so they are emitted in lexicographic order of the class indices.
        """
        source = self.data_source
        class_ids = range(source.n_classes)
        return combinations(class_ids, source.n_way)
class MetaRandomSampler(RandomSampler):
    """Sampler that draws random ``n_way``-sized tuples of class indices.

    One pass over the sampler yields as many tuples as there are
    ``n_way``-sized combinations of ``range(n_classes)``. Each tuple is
    drawn independently, so the same set of classes may appear more than
    once in a pass and some combinations may not appear at all.

    Args:
        data_source: Dataset exposing ``n_classes`` and ``n_way``
            attributes.
    """

    def __init__(self, data_source: MetaDataset):
        # The parent is configured with replacement=True; iteration itself
        # is fully overridden by __iter__ below.
        super().__init__(data_source, replacement=True)

    def __iter__(self):
        """Yield random tuples of ``n_way`` distinct class indices.

        Indices within a single tuple never repeat (``random.sample``
        draws without replacement), and their order is random.
        """
        source = self.data_source
        total_classes = source.n_classes
        per_task = source.n_way
        # The combinations iterator is used purely as a counter: it fixes
        # how many tuples one pass emits; its values are discarded.
        for _ignored in combinations(range(total_classes), per_task):
            picked = random.sample(range(total_classes), per_task)
            yield tuple(picked)