"""Base class for meta-learning algorithms (module ``metallic.metalearners.base``)."""

import os
from abc import ABC, abstractmethod
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn, optim

class MetaLearner(ABC):
    """
    A base class for all meta-learning algorithms.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped. It is moved to ``device`` on construction.

    root : str, optional
        Root directory to save checkpoints. A leading ``~`` is expanded to
        the user's home directory. If ``None``, ``self.root`` stays ``None``.

    save_basename : str, optional
        Base name of the saved checkpoints

    lr_scheduler : optional
        Learning rate scheduler. It must expose a ``step()`` method. If
        ``None``, :meth:`lr_schedule` does nothing.

    loss_function : callable, optional
        Loss function. Defaults to :class:`torch.nn.CrossEntropyLoss`.

    device : str or torch.device, optional
        Device on which the model is defined. If ``None``, CUDA is used when
        available, otherwise CPU.
    """

    def __init__(
        self,
        model: nn.Module,
        root: Optional[str] = None,
        save_basename: Optional[str] = None,
        lr_scheduler: Optional[Callable] = None,
        loss_function: Optional[Callable] = None,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = model.to(device)

        # `root` is optional, but os.path.expanduser(None) raises TypeError,
        # so only expand it when a path was actually supplied.
        self.root = os.path.expanduser(root) if root is not None else None
        self.save_basename = save_basename
        self.lr_scheduler = lr_scheduler

        if loss_function is None:
            loss_function = nn.CrossEntropyLoss()
        self.loss_function = loss_function

    def get_tasks(self, batch: dict) -> tuple:
        """
        Move a batch of tasks to ``self.device`` and split it per task.

        Parameters
        ----------
        batch : dict
            Must contain ``'support'`` and ``'query'`` entries. Each entry is
            an ``(inputs, targets)`` pair of tensors. The first dimension is
            presumably the task index, since the tensors are zipped along it.

        Returns
        -------
        task_batch : zip
            Iterator yielding ``(support_inputs, support_targets,
            query_inputs, query_targets)`` for each task. It can be consumed
            only once.

        n_tasks : int
            Number of tasks, read from ``query_targets.size(0)``.
        """
        # support set
        support_inputs, support_targets = batch['support']
        support_inputs = support_inputs.to(self.device)
        support_targets = support_targets.to(self.device)

        # query set
        query_inputs, query_targets = batch['query']
        query_inputs = query_inputs.to(self.device)
        query_targets = query_targets.to(self.device)

        # number of tasks
        n_tasks = query_targets.size(0)

        task_batch = zip(
            support_inputs, support_targets, query_inputs, query_targets
        )
        return task_batch, n_tasks

    def lr_schedule(self) -> None:
        """Advance the learning rate scheduler, if one was configured."""
        # `lr_scheduler` is optional in __init__. Without this guard, calling
        # this method with no scheduler raised AttributeError on None.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    @classmethod
    @abstractmethod
    def load(cls, model_path: str, **kwargs):
        """
        Load a trained model.

        Parameters
        ----------
        model_path : str
            Path to the saved model.

        **kwargs
            Extra arguments interpreted by the concrete subclass.
        """

    @abstractmethod
    def save(self, prefix: Optional[str] = None) -> str:
        """
        Save the trained model.

        Parameters
        ----------
        prefix : str, optional
            Optional prefix for the saved file name. Its exact use is defined
            by subclasses.

        Returns
        -------
        str
            As declared by the signature. Presumably the path of the saved
            checkpoint; confirm per subclass.
        """

    @abstractmethod
    def step(self, batch: dict, meta_train: bool = True) -> Tuple[float, ...]:
        """
        Run one meta-learning step on a batch of tasks.

        Parameters
        ----------
        batch : dict
            Batch of tasks, in the format consumed by :meth:`get_tasks`.

        meta_train : bool, optional, default=True
            Whether this step is for meta-training. The exact effect (for
            example, whether the model is updated) is defined by subclasses.

        Returns
        -------
        tuple of float
            Metrics for the step. Their contents are defined by subclasses.
        """