"""metallic.data.dataloader: a ``DataLoader`` wrapper for meta-learning datasets."""

from typing import Optional, Callable
from torch.utils.data import DataLoader, Sampler

from .datasets import MetaDataset
from .sampler import *
from . import _utils

class MetaDataLoader(DataLoader):
    """A :class:`torch.utils.data.DataLoader` specialised for :class:`MetaDataset`.

    Batches are always collated with ``_utils.MetaCollate()``. Unless the
    caller supplies a ``sampler`` or a ``batch_sampler``, a default sampler is
    chosen from ``shuffle``:

    - ``MetaRandomSampler(dataset)`` when ``shuffle`` is true;
    - ``MetaSequentialSampler(dataset)`` otherwise.

    Args:
        dataset: The meta-dataset to load from.
        batch_size: Forwarded to ``DataLoader``.
        shuffle: Only used to pick the default sampler (see above). It is
            never forwarded to ``DataLoader``, which always receives
            ``shuffle=False``. Ordering is delegated to the sampler, and
            ``DataLoader`` rejects ``shuffle=True`` combined with an
            explicit sampler. Ignored when ``sampler`` or ``batch_sampler``
            is given.
        sampler: Custom sampler; replaces the default one.
        batch_sampler: Custom batch sampler. When given, no default sampler
            is created, because ``DataLoader`` forbids combining the two.
            ``DataLoader`` then also requires ``batch_size=1`` and
            ``drop_last=False``.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        timeout: Forwarded to ``DataLoader``.
        worker_init_fn: Optional callable that ``DataLoader`` invokes with the
            worker id in each worker process; ``None`` means no hook.
    """

    def __init__(
        self,
        dataset: MetaDataset,
        batch_size: int = 1,
        shuffle: bool = True,
        sampler: Optional[Sampler] = None,
        batch_sampler: Optional[Sampler] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0.,
        # Previously written as `worker_init_fn = Callable[[int], None]`, which
        # made the typing object itself the default *value*; DataLoader would
        # then try to call it in every worker and fail.
        worker_init_fn: Optional[Callable[[int], None]] = None,
    ) -> None:
        collate_fn = _utils.MetaCollate()

        # Only install a default sampler when the caller provided neither a
        # sampler nor a batch sampler: DataLoader raises ValueError if a
        # batch_sampler is combined with a sampler.
        if sampler is None and batch_sampler is None:
            if shuffle:
                sampler = MetaRandomSampler(dataset)
            else:
                sampler = MetaSequentialSampler(dataset)

        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            # Any shuffling is already encoded in the chosen sampler.
            shuffle=False,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )