Source code for ecgan.anomaly_detection.detector.detector_factory

"""Factory to return AnomalyDetector objects."""
from ecgan.anomaly_detection.detector.base_detector import AnomalyDetector
from ecgan.anomaly_detection.detector.classification_detector import ArgmaxClassifierDetector
from ecgan.anomaly_detection.detector.reconstruction_detector import GANAnomalyDetector, GANInverseAnomalyDetector
from ecgan.evaluation.tracker import BaseTracker
from ecgan.modules.base import BaseModule
from ecgan.modules.classifiers.nn_classifier import NNClassifier
from ecgan.modules.generative.base import BaseGANModule
from ecgan.utils.custom_types import AnomalyDetectionStrategies, ReconstructionType


[docs]class AnomalyDetectorFactory: """Meta module for creating correct anomaly detectors."""
[docs] @staticmethod def choose_class(module: AnomalyDetectionStrategies): """Choose the correct class based on the provided module name.""" anomaly_detectors = { AnomalyDetectionStrategies.ANOGAN: GANAnomalyDetector, AnomalyDetectionStrategies.ARGMAX: ArgmaxClassifierDetector, AnomalyDetectionStrategies.INVERSE_MAPPING: GANInverseAnomalyDetector, } try: return anomaly_detectors[module] except KeyError as err: raise AttributeError('Argument {0} is not set correctly.'.format(module)) from err
[docs] def __call__(self, detector: str, module: BaseModule, tracker: BaseTracker) -> AnomalyDetector: """Return implemented AD module when a BaseModule is created.""" if detector == AnomalyDetectionStrategies.ANOGAN.value: if not isinstance(module, BaseGANModule): raise TypeError( 'BaseGANModule is expected for AnoGAN detection, current model type: {0}'.format(type(module)) ) return GANAnomalyDetector( module=module, reconstructor=ReconstructionType.INTERPOLATE, tracker=tracker, ) if detector == AnomalyDetectionStrategies.INVERSE_MAPPING.value: if not isinstance(module, BaseGANModule): raise TypeError( 'BaseGANModule is expected for inverse mapping, current module type: {0}'.format(type(module)) ) return GANInverseAnomalyDetector( module=module, reconstructor=ReconstructionType.INVERSE_MAPPING, tracker=tracker, ) if detector == AnomalyDetectionStrategies.ARGMAX.value: if not isinstance(module, NNClassifier): raise TypeError('NNClassifier is expected, current type: {0}'.format(type(module))) return ArgmaxClassifierDetector(module=module, tracker=tracker) raise AttributeError('Argument {0} is not set correctly.'.format(detector))