# Source code for ecgan.anomaly_detection.detector.base_detector

"""Base class used for anomaly detection."""
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Dict, Union

import numpy as np
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from ecgan.config import get_global_ad_config
from ecgan.evaluation.tracker import BaseTracker
from ecgan.modules.base import BaseModule
from ecgan.modules.classifiers.base import BaseClassifier
from ecgan.training.datasets import SeriesDataset
from ecgan.utils.artifacts import FileArtifact
from ecgan.utils.configurable import Configurable

logger = getLogger(__name__)


class AnomalyDetector(Configurable, ABC):
    """
    A base class for various (PyTorch based) anomaly detectors.

    This class can be used to implement general anomaly detection algorithms and it is not limited to
    deep learning/machine learning approaches. Each :code:`AnomalyDetector` should be able to

    #. Assign labels to a given time series.
    #. Save/load relevant evaluation data to/from a pkl file. This includes at least the labels but can
       be expanded for arbitrary information such as anomaly scores or data obtained from reconstructions.

    The labeling can depend on the type of detection (e.g. one score per series, channel or point).
    Thus, no specific format will be enforced. In general, we will save the labels per series and if
    scores are used to determine if some point is anomalous, we use pointwise scoring whenever possible,
    since channel-/serieswise anomalies can usually be reconstructed based on that data.

    The actual performance measures are controlled by the AnomalyManager based on the predicted labels.
    """

    def __init__(self, module: Union[BaseModule, BaseClassifier], tracker: BaseTracker):
        """
        Store the detection module, the tracker and the global anomaly detection config.

        Args:
            module: Module or classifier used by the concrete detector. If it is a :code:`BaseModule`,
                :meth:`detect` moves each batch to ``module.device`` before calling :meth:`_detect`.
            tracker: Tracker used by :meth:`save` to log the detection artifacts.
        """
        self.module = module
        self.cfg = get_global_ad_config()
        self.tracker = tracker

    @abstractmethod
    def _detect(self, data: Tensor) -> Tensor:
        """
        Detect anomalies based on the desired detection scheme and return the asserted class labels.

        Args:
            data: Tensor (usually of size [batch_size, seq_len, channels]) of data which shall be classified.

        Returns:
            A Tensor with the label predictions for `data`.
        """
        raise NotImplementedError("AnomalyDetector needs to implement the `_detect` method.")

    def detect(self, test_x: Tensor, test_y: Tensor) -> np.ndarray:
        """
        Detect anomalies based on the desired detection scheme and return the asserted class labels.

        Data is expected to be shuffled when passed to the detect method. It is then fed into a
        DataLoader (which keeps the given order) and chunked into batches on which anomalies are
        detected by the abstract :meth:`_detect` method. Progress is shown with a tqdm progress bar.

        Args:
            test_x: The shuffled test data.
            test_y: The labels corresponding to test_x. Only its shape is used to allocate the
                output array (and it is passed to the dataset alongside the data).

        Returns:
            Predicted labels as a numpy array with the same shape as `test_y`.
        """
        predicted_labels = np.empty(test_y.shape)
        test_dataset = SeriesDataset(
            test_x.float(),
            test_y.float(),
        )
        dataloader = DataLoader(
            dataset=test_dataset,
            batch_size=self.cfg.detection_config.BATCH_SIZE,
            num_workers=self.cfg.detection_config.NUM_WORKERS,
            pin_memory=True,
        )

        # Track the write position with a running offset derived from the actual number of series in
        # each batch, instead of ``batch_idx * dataloader.batch_size``. The old fallback of 0 for an unset
        # batch size silently wrote nothing and returned uninitialized memory.
        batch_start_idx = 0
        for batch in tqdm(dataloader, leave=False):
            data = batch['data'].to(self.module.device) if isinstance(self.module, BaseModule) else batch['data']
            predicted_batch_labels = self._detect(data=data)
            if isinstance(predicted_batch_labels, Tensor):
                # `.numpy()` raises for CUDA tensors and for tensors that require grad; the input was
                # moved to `self.module.device`, so detach and copy to host memory first.
                predicted_batch_labels = predicted_batch_labels.detach().cpu().numpy()
            batch_end_idx = batch_start_idx + len(batch['data'])
            predicted_labels[batch_start_idx:batch_end_idx] = predicted_batch_labels
            batch_start_idx = batch_end_idx

        return predicted_labels

    @abstractmethod
    def _get_data_to_save(self) -> Dict:
        """Return a dict of objects which shall be saved using the tracker."""
        raise NotImplementedError("AnomalyDetector needs to implement the `_get_data_to_save` method.")

    def save(self, run_id: str) -> None:
        """
        Save anomaly detection results to the tracker.

        The dict returned by :meth:`_get_data_to_save` is logged as a :code:`FileArtifact` with the
        file name ``detection_data_<run_id>.pkl``.

        Args:
            run_id: Identifier embedded in the artifact's file name.
        """
        self.tracker.log_artifacts(
            FileArtifact(
                'Anomaly Detection Data',
                self._get_data_to_save(),
                'detection_data_{}.pkl'.format(run_id),
            )
        )

    @abstractmethod
    def load(self, saved_data: Dict) -> None:
        """
        Load AD data from dict.

        The provided dict is usually part of an output of a previous AD run. The `load` method loads the
        saved data into the instantiated detector, which can subsequently be used to e.g. create
        embeddings without reprocessing all data. The user is tasked to retrieve the saved dict
        themselves.

        Args:
            saved_data: Previously saved data, loaded into variables of the respective detector.
        """
        raise NotImplementedError("AnomalyDetector needs to implement the `load` method.")