"""Implementation of different distance metrics."""
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Union
import numpy as np
import torch
from torch import Tensor, mean
from torch.nn import L1Loss, MSELoss
from ecgan.utils.miscellaneous import to_torch
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
class DistanceMetric(ABC):
    """A base class for different distance metrics to inherit from."""

    def __call__(
        self,
        point_1: Union[np.ndarray, Tensor],
        point_2: Union[np.ndarray, Tensor],
    ) -> Tensor:
        """Call the calculate method."""
        return self.calculate(point_1, point_2)

    @abstractmethod
    def calculate(
        self,
        point_1: Union[np.ndarray, Tensor],
        point_2: Union[np.ndarray, Tensor],
    ) -> Tensor:
        """
        Calculate the distance between two points (arrays of same size).

        Args:
            point_1: Some data with at least 1 dimension.
            point_2: Some data with at least 1 dimension.

        Returns:
            The distance.
        """
        raise NotImplementedError("Distance needs to implement the `calculate` method.")

    @staticmethod
    def _reduction(pairwise_distance: Tensor, reduction: str) -> Tensor:
        """
        Reduce a tensor of distances according to ``reduction``.

        Args:
            pairwise_distance: The (unreduced) distances.
            reduction: ``'mean'`` averages all entries, ``'sum'`` sums all entries and
                ``'none'`` returns ``pairwise_distance`` unchanged.

        Returns:
            A 0-dim tensor for ``'mean'``/``'sum'``, otherwise ``pairwise_distance`` itself.

        Raises:
            ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
        """
        if reduction == 'mean':
            return torch.mean(pairwise_distance)
        if reduction == 'sum':
            return torch.sum(pairwise_distance)
        if reduction == 'none':
            return pairwise_distance
        # Reject unknown modes instead of silently skipping the reduction, so typos such as
        # 'avg' surface immediately (same error type/message style as torch's loss modules).
        raise ValueError(f"{reduction} is not a valid value for reduction")
class MinkowskiDistance(DistanceMetric):
    """
    Implementation of the Minkowski distance of two vectors.

    p=1: Manhattan Distance, p=2: Euclidean distance. Default is p=3.
    Coordinates are taken from the last dimension of the inputs; any leading dimensions
    are treated as batch dimensions.
    """

    def __init__(self, order: int = 3, reduction: str = 'none'):
        self.order = order
        self.reduction = reduction

    def calculate(
        self,
        point_1: Union[np.ndarray, Tensor],
        point_2: Union[np.ndarray, Tensor],
    ) -> Tensor:
        r"""
        Calculate the Minkowski distance.

        Computes :math:`\left(\sum_i |x_i - y_i|^p\right)^{1/p}` over the last dimension,
        where :math:`p` is ``self.order``.

        Args:
            point_1: Coordinates of one point (or several points, coordinates in the last dimension).
            point_2: Coordinates of another point (shape broadcastable with ``point_1``).

        Returns:
            The Minkowski distance of point_1 and point_2. With ``reduction='none'`` the shape is the
            broadcast input shape without its last dimension (0-dim for 1-D inputs); ``'mean'`` and
            ``'sum'`` reduce over all remaining entries.
        """
        point_1 = torch.from_numpy(point_1) if isinstance(point_1, np.ndarray) else point_1
        point_2 = torch.from_numpy(point_2) if isinstance(point_2, np.ndarray) else point_2
        # The coordinates must be summed *before* taking the p-th root; applying the root
        # element-wise would collapse (|d|^p)^(1/p) to |d| for every order.
        powered = torch.abs(point_1 - point_2) ** self.order
        pairwise_distance = torch.sum(powered, dim=-1) ** (1 / self.order)
        return self._reduction(pairwise_distance, self.reduction)
class L1Distance(DistanceMetric):
    """Implementation of the :math:`L_1`-distance."""

    def __init__(self, reduction: str = 'none'):
        self.reduction = reduction

    def calculate(
        self,
        point_1: Union[np.ndarray, Tensor],
        point_2: Union[np.ndarray, Tensor],
    ) -> Tensor:
        r"""
        Return the average :math:`L_1` distance per sample in the batch.

        The element-wise :math:`L_1` distance of any shape - usually :math:`(b \times c \times s)` or
        :math:`(b \times c)` - is calculated, reshaped to :math:`(b, -1)` and averaged per sample.

        Args:
            point_1: First batch of data, batch dimension first. NumPy arrays are converted to tensors.
            point_2: Second batch of data with the same shape as ``point_1``.

        Returns:
            Per-sample distances of shape :math:`(b,)` for ``reduction='none'``; their mean or sum
            (0-dim tensor) for ``'mean'``/``'sum'``.

        Raises:
            ValueError: If ``self.reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
        """
        if self.reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"{self.reduction} is not a valid value for reduction")
        point_1 = torch.from_numpy(point_1) if isinstance(point_1, np.ndarray) else point_1
        point_2 = torch.from_numpy(point_2) if isinstance(point_2, np.ndarray) else point_2
        # Keep the loss element-wise: reducing inside L1Loss would yield a scalar that cannot be
        # reshaped to (b, -1) for b > 1. The requested reduction is applied to per-sample values.
        elementwise = L1Loss(reduction='none')(point_1, point_2)
        per_sample = mean(elementwise.reshape(point_1.shape[0], -1), dim=1)
        return self._reduction(per_sample, self.reduction)
class L2Distance(DistanceMetric):
    """Implementation of the :math:`L_2`-distance."""

    def __init__(self, reduction: str = 'none'):
        self.reduction = reduction

    def calculate(
        self,
        point_1: Union[np.ndarray, Tensor],
        point_2: Union[np.ndarray, Tensor],
    ) -> Tensor:
        r"""
        Return the average :math:`L_2` distance per sample in the batch.

        The element-wise squared difference (via ``MSELoss``, i.e. no square root is taken) of any
        shape - usually :math:`(b \times c \times s)` or :math:`(b \times c)` - is calculated,
        reshaped to :math:`(b, -1)` and averaged per sample.

        Args:
            point_1: First batch of data, batch dimension first. NumPy arrays are converted to tensors.
            point_2: Second batch of data with the same shape as ``point_1``.

        Returns:
            Per-sample mean squared differences of shape :math:`(b,)` for ``reduction='none'``; their
            mean or sum (0-dim tensor) for ``'mean'``/``'sum'``.

        Raises:
            ValueError: If ``self.reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
        """
        if self.reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"{self.reduction} is not a valid value for reduction")
        point_1 = torch.from_numpy(point_1) if isinstance(point_1, np.ndarray) else point_1
        point_2 = torch.from_numpy(point_2) if isinstance(point_2, np.ndarray) else point_2
        # Keep the loss element-wise: reducing inside MSELoss would yield a scalar that cannot be
        # reshaped to (b, -1) for b > 1. The requested reduction is applied to per-sample values.
        elementwise = MSELoss(reduction='none')(point_1, point_2)
        per_sample = mean(elementwise.reshape(point_1.shape[0], -1), dim=1)
        return self._reduction(per_sample, self.reduction)