Source code for ecgan.manager

"""Manages initialization, preprocessing, training and anomaly detection of ecgan."""
import os

from ecgan.anomaly_detection.anomaly_manager import AnomalyManager
from ecgan.config import (
    ExperimentConfig,
    GANModuleConfig,
    TrainConfig,
    TrainerConfig,
    get_global_ad_config,
    get_global_config,
    get_run_config,
    set_global_ad_config,
    set_global_config,
    set_global_inv_config,
)
from ecgan.config.initialization.anomaly_detection import init_detection
from ecgan.config.initialization.inverse import init_inverse
from ecgan.evaluation.tracker import TrackerFactory
from ecgan.modules.factory import ModuleFactory
from ecgan.modules.inverse_mapping.inversion import inverse_train
from ecgan.preprocessing.data_retrieval import DataRetrieverFactory, retrieve_fold_from_existing_split
from ecgan.preprocessing.preprocessor import PreprocessorFactory
from ecgan.training.trainer import Trainer
from ecgan.utils.datasets import DatasetFactory, SineDataset
from ecgan.utils.log import set_log_level, setup_logger
from ecgan.utils.miscellaneous import load_pickle_numpy, merge_dicts, update_dicts
from ecgan.utils.parser import config_parser, detection_parser, init_parser, inverse_parser

logger = setup_logger('ecgan')


def init():
    """
    Initialize ecgan by generating a config file.

    The command line arguments are read via `init_parser`. One configuration
    dict is requested per component (experiment, preprocessing, trainer and
    module), the dicts are combined into a single base configuration, and the
    result is written out through `TrainConfig.generate_config_file`.
    """
    data_location, dataset, entity, module, name, output_file, project = init_parser()

    # Each component returns a dict holding its own configuration.
    # Additionally, the 'update' key can be used to update other components.
    experiment_part = ExperimentConfig.configure(
        project=project,
        experiment_name=name,
        module=module,
        dataset=dataset,
        entity=entity,
    )
    preprocessing_part = DataRetrieverFactory.choose_class(dataset).configure()
    num_channels = DatasetFactory()(dataset).num_channels
    trainer_part = TrainerConfig.configure(channels=num_channels)
    module_part = ModuleFactory.choose_class(module_name=module).configure()

    # The data location comes from the CLI and overrides the retriever's value.
    preprocessing_part['preprocessing']['LOADING_DIR'] = data_location

    base_config = {}
    pending_updates = {}
    for component in (experiment_part, preprocessing_part, trainer_part, module_part):
        update_dicts(base_config, pending_updates, component)

    # Apply the collected updates only after every component has been added.
    merge_dicts(base_config, pending_updates)

    TrainConfig(base_config=base_config, output_file=output_file).generate_config_file()


def preprocess():
    """
    Start preprocessing a dataset with parameters defined in the config file.

    Requires a config file generated by `ecgan-init`.

    Raw data will be downloaded to the provided data path to
    '<data_path>/<dataset_name>/raw' in an arbitrary format given by the
    different datasets. The data will then be preprocessed and saved to
    '<data_path>/<dataset_name>/processed' in a format suitable for all other
    ecgan components.
    """
    filename = config_parser()
    set_global_config(filename)
    preprocess_config = get_global_config().preprocessing_config
    exp_config = get_global_config().experiment_config

    # Retrieve the raw data first; this happens for every dataset.
    DataRetrieverFactory()(dataset=exp_config.DATASET, cfg=preprocess_config).load()

    logger.info('Executing preprocessing procedure for dataset {0}.'.format(exp_config.DATASET))

    if exp_config.DATASET == SineDataset.name:
        logger.info('Sine dataset is already in matching format. Skipping preprocessing.')
        return

    # Call the factory instance directly, like every other factory in this
    # module, instead of invoking the `__call__` dunder by name.
    preprocessor = PreprocessorFactory()(preprocess_config, dataset=exp_config.DATASET)
    # The two returned values are not needed here; `save()` is called on the
    # preprocessor right after (presumably persisting the processed data).
    _, _ = preprocessor.preprocess()
    preprocessor.save()


def train():
    """
    Launch a training run.

    Requires a config file generated by `ecgan-init`.
    Requires data preprocessed by `ecgan-preprocess`.

    The pickled data and labels are read from
    '<LOADING_DIR>/<DATASET>/processed' and handed to a `Trainer`, which is
    then fitted.
    """
    set_global_config(config_parser())

    experiment = get_global_config().experiment_config
    preprocessing = get_global_config().preprocessing_config
    set_log_level(experiment.TRACKER.LOG_LEVEL)

    processed_dir = os.path.join(preprocessing.LOADING_DIR, experiment.DATASET, 'processed')
    data, label = (
        load_pickle_numpy(os.path.join(processed_dir, pickle_name))
        for pickle_name in ('data.pkl', 'label.pkl')
    )

    Trainer(data=data, label=label).fit()


def inverse():
    """
    Run an inverse mapping for a trained model.

    The inverse mapping needs its own configuration. That configuration has to
    be generated with the init (`-i`) flag before a run can be started; when
    the flag is given, only the configuration is created and no run is started.
    """
    cli_args = inverse_parser()
    config_file = cli_args.out

    if not cli_args.init:
        set_global_inv_config(config_file)
        inverse_train()
        return

    # Init mode: only create the inverse mapping configuration.
    init_inverse(path=cli_args.init, filename=config_file)


def detect():
    """
    Run the anomaly detection process.

    Requires generating an anomaly detection config file first using the `-i`
    flag. `entity`, `project` and `run_name` are mandatory during init to
    create the config file and not permitted as CLI parameters afterwards.

    Requires data preprocessed by `ecgan-preprocess`.
    Requires trained model from `ecgan-train`.
    """
    cli_args = detection_parser()
    ad_config_file = cli_args.config

    if cli_args.init:
        # Init mode: only create the anomaly detection configuration.
        init_detection(cli_args)
        return

    set_global_ad_config(ad_config_file)
    ad_config = get_global_ad_config()

    # The configuration of the referenced run becomes the global configuration.
    run_config = get_run_config(ad_config)
    set_global_config(run_config.config_dict)

    trainer_cfg = get_global_config().trainer_config
    exp_cfg = get_global_config().experiment_config
    set_log_level(exp_cfg.TRACKER.LOG_LEVEL)

    ad_exp_cfg = ad_config.ad_experiment_config
    tracker = TrackerFactory()(config=ad_exp_cfg)

    start_message = 'Starting anomaly detection process with dataset {0} and model stored in {1}.'.format(
        run_config.experiment_config.DATASET, ad_config.ad_experiment_config.RUN_URI
    )
    logger.info(start_message)

    # NOTE(review): train() reads LOADING_DIR from the preprocessing config,
    # while this reads it from the experiment config - confirm both point to
    # the same directory.
    data_source_dir = os.path.join(exp_cfg.LOADING_DIR, exp_cfg.DATASET, 'processed')

    train_x, test_x, vali_x, train_y, test_y, vali_y = retrieve_fold_from_existing_split(
        data_dir=data_source_dir,
        split_path=ad_config.ad_experiment_config.RUN_URI,
        split_file=trainer_cfg.SPLIT_PATH,
        fold=ad_config.ad_experiment_config.FOLD,
        location=exp_cfg.TRACKER.tracker_name,
        target_dir=tracker.run_dir,
    )

    # Only GAN modules carry a generator setting that controls the rescaling.
    module_cfg = get_global_config().module_config
    rescale_to_unit_circle = (
        module_cfg.GENERATOR.TANH_OUT if isinstance(module_cfg, GANModuleConfig) else False
    )

    train_x, test_x, vali_x, train_y, test_y, vali_y = Trainer.prepare_data(
        train_x,
        test_x,
        vali_x,
        train_y,
        test_y,
        vali_y,
        trainer_cfg,
        rescale_to_unit_circle,
    )

    # Sequence length and channel count are taken from axes 1 and 2 of train_x.
    manager = AnomalyManager(
        ad_config,
        seq_len=train_x.shape[1],
        tracker=tracker,
        num_channels=train_x.shape[2],
    )
    manager.start_detection(
        train_x=train_x,
        train_y=train_y,
        vali_x=vali_x,
        vali_y=vali_y,
        test_x=test_x,
        test_y=test_y,
    )