# Source code for metallic.functional.gradient

import torch
from torch import nn
from typing import Optional, Sequence, List

def apply_grads(model: nn.Module, grads: Sequence[torch.Tensor]) -> None:
    """
    Map a list of gradients to a model.

    Gradients are paired positionally with ``model.parameters()``. A
    parameter whose ``.grad`` is ``None`` receives a copy of its gradient;
    otherwise the gradient is added in place to the existing ``.grad``.
    If the number of gradients differs from the number of parameters, a
    warning is printed and only the leading pairs are processed.

    Parameters
    ----------
    model : nn.Module
        Model whose parameters receive the gradients

    grads : Sequence[torch.Tensor]
        List of gradient for each model parameter. ``None`` entries are
        skipped and leave the matching parameter untouched.
    """
    # Materialize once: the list is needed for both the length check and
    # the pairing loop below.
    params = list(model.parameters())

    if len(params) != len(grads):
        msg = (
            'WARNING: Parameters and gradients have different length. ('
            f'{len(params)} vs {len(grads)})'
        )
        print(msg)

    for param, grad in zip(params, grads):
        if grad is None:
            continue
        if param.grad is None:
            # Clone so the parameter never shares storage with the
            # caller's tensor.
            param.grad = grad.clone()
        else:
            # The in-place add only reads ``grad``, so no copy is needed.
            param.grad += grad
def accum_grads(grads: Sequence[Sequence[torch.Tensor]]) -> List[torch.Tensor]:
    """
    Sum gradients position-wise across several gradient lists.

    Parameters
    ----------
    grads : Sequence[Sequence[torch.Tensor]]
        One gradient list per task; entries at the same position are
        stacked together and summed

    Returns
    -------
    sum_grads : List[torch.Tensor]
        One summed gradient per position
    """
    sum_grads = []
    # zip(*grads) regroups the per-task lists into per-position tuples.
    for same_position in zip(*grads):
        stacked = torch.stack(same_position)
        sum_grads.append(stacked.sum(dim=0))
    return sum_grads