import warnings
from collections import OrderedDict
from numbers import Number
from typing import Union, Dict

import numpy as np
import torch
def get_accuracy(scores: torch.Tensor, targets: torch.Tensor) -> float:
    """Compute classification accuracy from predicted scores and targets.

    Args:
        scores: Per-class scores. The predicted class of each sample is the
            argmax over dimension 1 (presumably shape
            ``(n_samples, n_classes)`` -- TODO confirm with callers).
        targets: Ground-truth class indices, one per sample along dim 0.

    Returns:
        The fraction of samples whose predicted class equals the target, as
        a plain Python ``float`` (NaN when ``targets`` is empty).
    """
    predictions = scores.argmax(dim=1)  # (n_samples,)
    n_correct = torch.eq(predictions, targets).sum().float()
    # Convert to a Python float, as the annotation promises, so callers such
    # as Metric.update (which goes through NumPy) never receive a device tensor.
    return (n_correct / targets.size(0)).item()
class Metric:
    """Keep track of the values recorded for a single metric.

    Values are buffered in a Python list and converted to a NumPy array only
    when a statistic is requested; ``np.append`` would copy the whole array on
    every update, making ``n`` updates cost O(n**2).

    Args:
        name: Name of the metric.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._values = []  # flat list of np.float64 scalars

    def reset(self) -> None:
        """Clear all of the recorded data."""
        self._values = []

    def update(self, name: str, value: Union[Number, np.number]) -> None:
        """Record a new value.

        Args:
            name: Unused; kept for backward compatibility of the signature.
            value: Value to record. As with ``np.append``, an array-like value
                is flattened and each element is recorded as float64.
        """
        self._values.extend(np.ravel(np.asarray(value, dtype=np.float64)))

    def _as_array(self) -> np.ndarray:
        """Return the recorded values as a 1-D float64 array."""
        return np.asarray(self._values, dtype=np.float64)

    @property
    def mean(self) -> np.number:
        """Return the average of the recorded values (NaN if there are none)."""
        if not self._values:
            return np.float64(np.nan)
        return self._as_array().mean()

    @property
    def std(self) -> np.number:
        """Return the (population) std of the recorded values (NaN if none)."""
        if not self._values:
            return np.float64(np.nan)
        return self._as_array().std()

    @property
    def max(self) -> np.number:
        """Return the maximum recorded value, or ``-inf`` if there is none."""
        if not self._values:
            return -np.inf
        return self._as_array().max()

    @property
    def min(self) -> np.number:
        """Return the minimum recorded value, or ``inf`` if there is none."""
        if not self._values:
            return np.inf
        return self._as_array().min()

    @property
    def recent(self) -> np.number:
        """Return the most recently recorded value.

        Raises:
            IndexError: If no value has been recorded since the last reset.
        """
        if not self._values:
            raise IndexError(
                'No value has been recorded for metric `{}`'.format(self._name))
        return self._values[-1]
class MetricTracker:
    """Keep track of several named metrics.

    Args:
        *names: Names of metrics to register immediately. A name given more
            than once is registered once and a ``UserWarning`` is emitted.
    """

    def __init__(self, *names: str) -> None:
        self._metrics = OrderedDict()
        for name in names:
            self.add(name)

    def add(self, name: str) -> None:
        """Register a new metric called ``name``.

        If ``name`` is already tracked, a ``UserWarning`` is emitted and the
        existing metric, with its recorded data, is left untouched.
        """
        if name in self._metrics:
            warnings.warn(
                'The metric `{}` already exists in the tracking list. To avoid '
                'duplication, this metric is ignored'.format(name),
                UserWarning, stacklevel=2
            )
            return
        self._metrics[name] = Metric(name)

    def reset(self) -> None:
        """Clear the recorded data of every tracked metric."""
        for metric in self._metrics.values():
            metric.reset()

    def update(self, name: str, value: Union[Number, np.number]) -> None:
        """Record ``value`` for the metric called ``name``.

        Raises:
            KeyError: If ``name`` has not been registered with :meth:`add`.
        """
        self._metrics[name].update(name, value)

    @property
    def metrics(self) -> Dict[str, 'Metric']:
        """Return the ``OrderedDict`` mapping metric names to metrics."""
        return self._metrics

    def __getitem__(self, index: str) -> 'Metric':
        """Return the metric called ``index`` (``KeyError`` if unknown)."""
        return self._metrics[index]