Source code for ecgan.anomaly_detection.embedder

"""Embedding logic for anomaly detection."""
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from sklearn.base import BaseEstimator
from torch import Tensor

from ecgan.config import get_global_config
from ecgan.utils.datasets import DatasetFactory
from ecgan.utils.embeddings import assert_and_reshape_dim, calculate_umap
from ecgan.utils.miscellaneous import to_numpy
from ecgan.visualization.plotter import ScatterPlotter

logger = getLogger(__name__)


[docs]class Embedder(ABC): """ Abstract embedding class. Args: train_x: Data the initial embedding should be fit on. train_y: Labels the initial embedding should be fit on. dataset: Dataset identifier, has to be supported by :class:`ecgan.utils.datasets.DatasetFactory` """ def __init__( self, train_x: Union[Tensor, np.ndarray], train_y: Union[Tensor, np.ndarray], dataset: str, ): self._train_x = to_numpy(train_x) self._train_y = to_numpy(train_y) self.classes = list( DatasetFactory()(dataset).beat_types_binary.keys() if get_global_config().trainer_config.BINARY_LABELS else list(DatasetFactory()(dataset).beat_types.keys()) ) self.dataset = dataset self.plotter = ScatterPlotter() self._initial_embedding: Optional[np.ndarray] = None self._reducer: Optional[BaseEstimator] = None
[docs] @abstractmethod def get_initial_embedding(self) -> Tuple[np.ndarray, BaseEstimator]: """ Create an initial embedding and train a reducer. Returns: The embedding data as well as the reducer. """ raise NotImplementedError("Embedder needs to implement the `get_initial_embedding` method.")
[docs] def embed_test( self, test_x: Union[Tensor, np.ndarray], test_y: Union[Tensor, np.ndarray], include_initial_embedding: bool, ) -> Tuple[np.ndarray, np.ndarray]: """ Embed test data into embedding pre-trained on the training data. Args: test_x: Test data. test_y: Test labels. include_initial_embedding: Flag to indicate if the embedding trained on train_x should be returned as well. Returns: (embedding, labels) tuple. The labels have shifted to avoid conflicts with existing train labels. """ data = to_numpy(test_x) labels = to_numpy(test_y) return self._embed_test(data, labels, include_initial_embedding)
@abstractmethod def _embed_test( self, test_x: np.ndarray, test_y: np.ndarray, include_initial_embedding: bool ) -> Tuple[np.ndarray, np.ndarray]: """ Embed test data into embedding pre-trained on the training data. Args: test_x: Test data. test_y: Test labels. include_initial_embedding: Flag to indicate if the embedding trained on train_x should be returned as well. Returns: (embedding, labels) tuple. The labels have shifted to avoid conflicts with existing train labels. """ raise NotImplementedError("Embedder needs to implement the `_embed_test` method.")
[docs] def embed(self, embeddable_data: Tensor, labels: Optional[Tensor] = None) -> Tuple[np.ndarray, np.ndarray]: """ Embed novel data into given embedding. This data can be synthetic data. Args: embeddable_data: Any data which can be embedded into a trained embedding. labels: Labels corresponding to the embeddable data. Returns: (embedding, labels) tuple. The labels have shifted to avoid conflicts with existing train labels. """ data_ = to_numpy(embeddable_data) labels_ = to_numpy(labels) if labels is not None else None return self._embed(data_, labels_)
@abstractmethod def _embed(self, embeddable_data: np.ndarray, labels: Optional[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: """Embed the novel data into given train/test sets.""" raise NotImplementedError("Embedder needs to implement the `_embed` method.")
[docs] @abstractmethod def get_plot(self, embedding: np.ndarray, labels: np.ndarray) -> plt.Figure: """Retrieve a plot of the embedding as an np.ndarray.""" raise NotImplementedError("Embedder needs to implement the `get_plot` method.")
[docs] @abstractmethod def draw_interpolation_path( self, trace: Union[np.ndarray, Tensor], labels: Union[np.ndarray, Tensor], embedding: Optional[Union[np.ndarray, Tensor]] = None, fig_size: Tuple[float, float] = (15, 11.25), ) -> plt.Figure: """Draw an interpolation path of reconstructed data into an existing embedding.""" raise NotImplementedError("Embedder needs to implement the `draw_interpolation_path` method.")
@property def reducer(self) -> BaseEstimator: if self._reducer is None: _, reducer = self.get_initial_embedding() self._reducer = reducer return self._reducer @property def initial_embedding(self) -> np.ndarray: if self._initial_embedding is None: embedding, _ = self.get_initial_embedding() self._initial_embedding = embedding return self._initial_embedding @property def initial_labels(self) -> np.ndarray: return self._train_y
[docs]class UMAPEmbedder(Embedder): """Create UMAP embedding (on train/vali data) and allow embedding of novel (test) data.""" def __init__( self, train_x: Tensor, train_y: Tensor, dataset, ): super().__init__(train_x, train_y, dataset) self.total_classes = self.classes self.len_fit_labels: int = len(self.classes)
[docs] def get_initial_embedding(self) -> Tuple[np.ndarray, BaseEstimator]: """Create an initial embedding.""" logger.info('Creating initial UMAP embedding...') initial_embedding, reducer = calculate_umap(self._train_x, self.initial_labels, supervised_umap=True) logger.info('Created initial UMAP embedding.') # Store the reducer for future inference of the remaining data. # Can be used to avoid recalculation each run since the embeddings will usually be similar. self._reducer = reducer self._initial_embedding = initial_embedding return initial_embedding, reducer
def _embed_test( self, test_x: np.ndarray, test_y: np.ndarray, include_initial_embedding: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: """Embed remaining test data using a reducer fit onto the initial embedding.""" if self.reducer is None or self.initial_embedding is None: self.get_initial_embedding() test_data = assert_and_reshape_dim(test_x) embedding_test: np.ndarray = self.reducer.transform(test_data) # type: ignore # Shift all labels to avoid label conflicts test_labels: np.ndarray = test_y + self.len_fit_labels self.total_classes = self.total_classes + [classname + '_test' for classname in self.classes] if not include_initial_embedding: return embedding_test, test_labels embedding = np.concatenate((self.initial_embedding, embedding_test)) labels = np.concatenate((self._train_y, test_labels)) return embedding, labels def _embed(self, embeddable_data: np.ndarray, labels: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]: """ Create UMAP embedding of the real and synthetic data. Create initial embedding and reducer if this is the first run. The initial embedding contains all train data and a subset of the test data. This embedding is reused in subsequent runs to reduce the computational cost. """ if self.reducer is None or self.initial_embedding is None: self.get_initial_embedding() embedding = self.reducer.transform(assert_and_reshape_dim(embeddable_data)) # type: ignore self.total_classes = self.total_classes + [classname + '_embedded' for classname in self.classes] if labels is None: len_interpolated_labels = 2 * self.len_fit_labels updated_labels = np.full((embeddable_data.shape[0],), len_interpolated_labels) else: updated_labels = labels + 2 * self.len_fit_labels return embedding, updated_labels
[docs] def get_plot(self, embedding: np.ndarray, labels: np.ndarray) -> plt.Figure: """Retrieve a plot of the embedding as a matplotlib Figure.""" return self.plotter.plot_scatter( data=embedding, target=labels, fig_title='', classes=self.total_classes, )
[docs] def draw_interpolation_path( self, trace: Union[np.ndarray, Tensor], labels: Union[np.ndarray, Tensor], embedding: Optional[Union[np.ndarray, Tensor]] = None, fig_size: Tuple[float, float] = (15, 11.25), ) -> plt.Figure: """ Plot a trace/path in latent space between data in the trace tensor. Args: trace: The trace to draw. labels: The data labels. embedding: The embedding the trace should be embedded into. fig_size: Size of the matplotlib figure (width, height). Returns: A matplotlib Figure of the visualized embedding. """ embedding_np = to_numpy(embedding) if embedding is not None else self.initial_embedding labels_np = to_numpy(labels) trace_np: np.ndarray = to_numpy(trace) embedding_trace = self.reducer.transform(assert_and_reshape_dim(trace_np)) return self.plotter.plot_interpolation_path( data=embedding_np, labels=labels_np, trace=embedding_trace, fig_size=fig_size, classes=self.total_classes, )