Source code for metallic.functional.prototype

import torch

def get_prototypes(
    inputs: torch.FloatTensor,
    targets: torch.LongTensor,
) -> torch.FloatTensor:
    """
    Compute the **prototypes** for each class in the task. Each prototype is
    the mean vector of the embedded support points belonging to its class.

    Classes do not need to have the same number of support points, and the
    labels do not need to be the contiguous range ``0..n_way-1``: row ``i``
    of the result corresponds to the ``i``-th smallest distinct label in
    ``targets``. For contiguous labels ``0..n_way-1`` this means row ``i`` is
    the prototype of class ``i``.

    Args:
        inputs (torch.FloatTensor): Embeddings of the support points, with
            shape ``(n_samples, embed_dim)``
        targets (torch.LongTensor): Targets of the support points, with
            shape ``(n_samples)``

    Returns:
        prototypes (torch.FloatTensor): Prototypes for each class, with shape
            ``(n_way, embed_dim)``. An empty support set yields an empty
            ``(0, embed_dim)`` tensor.
    """
    # Map arbitrary (possibly non-contiguous) labels onto row indices
    # 0..n_way-1, ordered by sorted label value, so they can be used directly
    # as scatter indices without going out of range.
    classes, labels = torch.unique(targets, sorted=True, return_inverse=True)
    n_way = classes.size(0)  # number of classes per task
    embed_dim = inputs.size(-1)  # embedding size

    # Sum the embeddings of each class into its prototype row.
    indices = labels.unsqueeze(-1).expand_as(inputs)  # (n_samples, embed_dim)
    prototypes = inputs.new_zeros(n_way, embed_dim)  # (n_way, embed_dim)
    prototypes.scatter_add_(0, indices, inputs)

    # Divide by the actual number of samples in each class rather than a
    # single shared k_shot, so unbalanced support sets still yield true means.
    counts = torch.bincount(labels, minlength=n_way).to(inputs.dtype)  # (n_way,)
    return prototypes / counts.unsqueeze(-1)