"""Manages initialization, preprocessing, training and anomaly detection of ecgan."""
import os
from ecgan.anomaly_detection.anomaly_manager import AnomalyManager
from ecgan.config import (
ExperimentConfig,
GANModuleConfig,
TrainConfig,
TrainerConfig,
get_global_ad_config,
get_global_config,
get_run_config,
set_global_ad_config,
set_global_config,
set_global_inv_config,
)
from ecgan.config.initialization.anomaly_detection import init_detection
from ecgan.config.initialization.inverse import init_inverse
from ecgan.evaluation.tracker import TrackerFactory
from ecgan.modules.factory import ModuleFactory
from ecgan.modules.inverse_mapping.inversion import inverse_train
from ecgan.preprocessing.data_retrieval import DataRetrieverFactory, retrieve_fold_from_existing_split
from ecgan.preprocessing.preprocessor import PreprocessorFactory
from ecgan.training.trainer import Trainer
from ecgan.utils.datasets import DatasetFactory, SineDataset
from ecgan.utils.log import set_log_level, setup_logger
from ecgan.utils.miscellaneous import load_pickle_numpy, merge_dicts, update_dicts
from ecgan.utils.parser import config_parser, detection_parser, init_parser, inverse_parser
# Module-level logger used by the entry points defined in this file.
logger = setup_logger('ecgan')
def init():
    """
    Run initialization for ecgan and generate a config file.

    The command line arguments are read via `init_parser`. One configuration dict is requested from each
    component (experiment, preprocessing, trainer, module), the dicts are combined into a single base
    configuration and the result is handed to `TrainConfig`, which generates the config file.
    """
    data_location, dataset, entity, module, name, output_file, project = init_parser()

    # Each component returns a dict holding its own configuration.
    # Additionally, the 'update' key can be used to update other components.
    experiment_part = ExperimentConfig.configure(
        project=project,
        experiment_name=name,
        module=module,
        dataset=dataset,
        entity=entity,
    )
    preprocessing_part = DataRetrieverFactory.choose_class(dataset).configure()
    num_channels = DatasetFactory()(dataset).num_channels
    trainer_part = TrainerConfig.configure(channels=num_channels)
    module_part = ModuleFactory.choose_class(module_name=module).configure()

    # The data location comes from the command line, not from the retriever's own configuration.
    preprocessing_part['preprocessing']['LOADING_DIR'] = data_location

    base_config = {}
    pending_updates = {}
    component_parts = (experiment_part, preprocessing_part, trainer_part, module_part)
    for part in component_parts:
        update_dicts(base_config, pending_updates, part)
    # Merge the collected updates only after every component has been processed.
    merge_dicts(base_config, pending_updates)

    TrainConfig(base_config=base_config, output_file=output_file).generate_config_file()
def preprocess():
    """
    Start preprocessing a dataset with parameters defined in the config file.

    Requires a config file generated by `ecgan-init`. Raw data will be downloaded to the provided data path to
    '<data_path>/<dataset_name>/raw' in an arbitrary format given by the different datasets. The data will then be
    preprocessed and saved to '<data_path>/<dataset_name>/processed' in a format suitable for all other ecgan
    components.
    """
    set_global_config(config_parser())
    global_cfg = get_global_config()
    preprocess_config = global_cfg.preprocessing_config
    dataset_name = global_cfg.experiment_config.DATASET

    # Retrieval always runs first, independent of whether preprocessing is needed afterwards.
    retriever = DataRetrieverFactory()(dataset=dataset_name, cfg=preprocess_config)
    retriever.load()

    logger.info('Executing preprocessing procedure for dataset {0}.'.format(dataset_name))

    if dataset_name == SineDataset.name:
        logger.info('Sine dataset is already in matching format. Skipping preprocessing.')
        return

    preprocessor = PreprocessorFactory()(preprocess_config, dataset=dataset_name)
    # The two returned values are not needed here; the preprocessor is asked to persist its result itself.
    _, _ = preprocessor.preprocess()
    preprocessor.save()
def train():
    """
    Start a training run.

    Requires a config file generated by `ecgan-init`.
    Requires data preprocessed by `ecgan-preprocess`.

    The data and labels are read from 'data.pkl' and 'label.pkl' inside
    '<LOADING_DIR>/<DATASET>/processed' and passed to a `Trainer`, which is then fitted.
    """
    set_global_config(config_parser())
    global_cfg = get_global_config()
    experiment = global_cfg.experiment_config
    preprocessing = global_cfg.preprocessing_config

    set_log_level(experiment.TRACKER.LOG_LEVEL)

    processed_dir = os.path.join(preprocessing.LOADING_DIR, experiment.DATASET, 'processed')
    # Load the data first, then the labels.
    data, label = (
        load_pickle_numpy(os.path.join(processed_dir, pickle_name)) for pickle_name in ('data.pkl', 'label.pkl')
    )

    Trainer(data=data, label=label).fit()
def inverse():
    """
    Start an inverse mapping run for a trained model.

    The inverse mapping requires an additional configuration. This configuration has to be generated using the
    init (`-i`) flag before the run can be started.
    """
    cli_args = inverse_parser()
    config_file = cli_args.out

    if not cli_args.init:
        # Regular run: load the inverse mapping configuration and start training.
        set_global_inv_config(config_file)
        inverse_train()
        return

    # Init mode: only set up the inverse mapping configuration, no training is started.
    init_inverse(path=cli_args.init, filename=config_file)
def detect():
    """
    Start the anomaly detection process.

    Requires generating an anomaly detection config file first using the `-i` flag. `entity`, `project` and
    `run_name` are mandatory during init to create the config file and not permitted as CLI parameters afterwards.

    Requires data preprocessed by `ecgan-preprocess`.
    Requires trained model from `ecgan-train`.
    """
    cli_args = detection_parser()
    ad_config_file = cli_args.config
    if cli_args.init:
        # Init mode: delegate to the detection initialization and stop; no detection is run.
        init_detection(cli_args)
        return

    # Load the anomaly detection config, then install the config of the referenced run as global config.
    set_global_ad_config(ad_config_file)
    ad_config = get_global_ad_config()
    run_config = get_run_config(ad_config)
    set_global_config(run_config.config_dict)

    global_cfg = get_global_config()
    trainer_cfg = global_cfg.trainer_config
    exp_cfg = global_cfg.experiment_config
    ad_exp_cfg = ad_config.ad_experiment_config

    set_log_level(exp_cfg.TRACKER.LOG_LEVEL)
    tracker = TrackerFactory()(config=ad_exp_cfg)

    start_message = 'Starting anomaly detection process with dataset {0} and model stored in {1}.'
    logger.info(start_message.format(run_config.experiment_config.DATASET, ad_exp_cfg.RUN_URI))

    # NOTE(review): train() reads LOADING_DIR from the preprocessing config, while this reads it from the
    # experiment config - confirm that both refer to the same directory.
    processed_dir = os.path.join(exp_cfg.LOADING_DIR, exp_cfg.DATASET, 'processed')
    x_train, x_test, x_vali, y_train, y_test, y_vali = retrieve_fold_from_existing_split(
        data_dir=processed_dir,
        split_path=ad_exp_cfg.RUN_URI,
        split_file=trainer_cfg.SPLIT_PATH,
        fold=ad_exp_cfg.FOLD,
        location=exp_cfg.TRACKER.tracker_name,
        target_dir=tracker.run_dir,
    )

    # Only GAN modules carry a generator setting that decides about rescaling; all others use False.
    module_cfg = global_cfg.module_config
    rescale_to_unit_circle = module_cfg.GENERATOR.TANH_OUT if isinstance(module_cfg, GANModuleConfig) else False

    x_train, x_test, x_vali, y_train, y_test, y_vali = Trainer.prepare_data(
        x_train,
        x_test,
        x_vali,
        y_train,
        y_test,
        y_vali,
        trainer_cfg,
        rescale_to_unit_circle,
    )

    # Sequence length and channel count are taken from axes 1 and 2 of the prepared training data.
    seq_len = x_train.shape[1]
    num_channels = x_train.shape[2]
    manager = AnomalyManager(
        ad_config,
        seq_len=seq_len,
        tracker=tracker,
        num_channels=num_channels,
    )
    manager.start_detection(
        train_x=x_train,
        train_y=y_train,
        vali_x=x_vali,
        vali_y=y_vali,
        test_x=x_test,
        test_y=y_test,
    )