Source code for ecgan.training.trainer

"""Basic trainer class used to create data splits, initialize training modules, fit a model and collect metrics."""
from logging import getLogger
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from ecgan.config import TrainerConfig, get_global_config
from ecgan.evaluation.tracker import TrackerFactory
from ecgan.modules.base import BaseModule
from ecgan.modules.factory import ModuleFactory
from ecgan.modules.generative.base import BaseGANModule
from ecgan.training.datasets import SeriesDataset
from ecgan.utils.artifacts import FileArtifact, ImageArtifact
from ecgan.utils.custom_types import SplitMethods, Transformation, save_epoch_metrics
from ecgan.utils.miscellaneous import (
    list_from_tuple_list,
    load_pickle,
    nested_list_from_dict_list,
    save_epoch,
    scale_to_unit_circle,
    to_torch,
    update_highest_metrics,
)
from ecgan.utils.splitting import create_splits, load_split, select_channels, verbose_channel_selection
from ecgan.utils.transformation import get_transformation
from ecgan.visualization.evaluation import boxplot
from ecgan.visualization.plotter import FourierPlotter, PlotterFactory

logger = getLogger(__name__)


class Trainer:
    """
    Load all required elements (data, config, tracking, plotter) and initialize the model.

    Requires a previously set config (via :func:`ecgan.config.global_cfg.set_global_config`).
    """

    def __init__(
        self,
        data: Union[Tensor, np.ndarray],
        label: Union[Tensor, np.ndarray],
    ):
        self.trainer_cfg = get_global_config().trainer_config
        self.exp_cfg = get_global_config().experiment_config
        self.module_cfg = get_global_config().module_config

        torch.manual_seed(self.trainer_cfg.MANUAL_SEED)

        self.tracker = TrackerFactory()(config=self.exp_cfg)
        logger.info(
            'Starting training process with dataset {0} using model {1}.'.format(
                self.exp_cfg.DATASET, self.exp_cfg.MODULE
            )
        )

        self.data = to_torch(data).float()
        self.label = to_torch(label).float()

        plotter = PlotterFactory.from_config(self.trainer_cfg)
        num_channels = (
            self.trainer_cfg.CHANNELS
            if isinstance(self.trainer_cfg.CHANNELS, int)
            else len(self.trainer_cfg.CHANNELS)
        )
        # If the Fourier transformation is chosen: double the amount of channels (imaginary + real part).
        channel_multiplier = 2 if isinstance(plotter, FourierPlotter) else 1
        self.num_channels = channel_multiplier * num_channels

        module_factory = ModuleFactory()
        self.module = module_factory(
            cfg=self.module_cfg,
            module=self.exp_cfg.MODULE,
            seq_len=data.shape[1],
            num_channels=self.num_channels,
        )

        self.tracker.log_config(get_global_config().config_dict)
        self.tracker.watch(self.module.watch_list)
        self.tracker.plotter = plotter
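A minimal usage sketch, assuming the global config has already been registered via :func:`ecgan.config.global_cfg.set_global_config` (its exact signature is assumed here) and that the data is shaped ``(samples, seq_len, channels)``; the shapes below are illustrative only.

# Hypothetical end-to-end usage; the array shapes and config source are assumptions.
import numpy as np
from ecgan.config.global_cfg import set_global_config
from ecgan.training.trainer import Trainer

cfg_dict = ...  # a previously parsed ECGAN configuration (see ecgan.config).
set_global_config(cfg_dict)

data = np.random.randn(1000, 128, 2)   # (samples, seq_len, channels)
label = np.random.randint(0, 2, 1000)  # 0: normal, non-zero: abnormal

trainer = Trainer(data, label)
module = trainer.fit()  # returns the trained BaseModule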
    def fit(self) -> BaseModule:
        """Fit a model on a split."""
        split_indices = self.get_split()
        self._fit_cross_val(split_indices)
        self.tracker.close()

        return self.module
    def get_split(self) -> Dict:
        """Retrieve an existing split or create a new split with the given number of folds."""
        if self.trainer_cfg.SPLIT_METHOD == 'fixed':
            split_indices: Dict = load_pickle(self.trainer_cfg.SPLIT_PATH)
        else:
            split_method = (
                SplitMethods.NORMAL_ONLY if self.trainer_cfg.TRAIN_ONLY_NORMAL else SplitMethods.MIXED
            )
            split_indices = create_splits(
                self.data,
                self.label,
                seed=self.trainer_cfg.MANUAL_SEED,
                folds=self.trainer_cfg.CROSS_VAL_FOLDS,
                method=split_method,
                split=self.trainer_cfg.SPLIT,
            )
        self.tracker.log_artifacts(FileArtifact('split indices', split_indices, 'split.pkl'))

        return split_indices
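The logged ``split.pkl`` can be reused later via the ``fixed`` split method. A sketch of reloading a tracked split by hand; the internal structure of the index dict is owned by ``create_splits``/``load_split`` and not assumed here, and ``data``/``label`` are the arrays the split was created from.

# Sketch: reload a previously logged split and materialize one fold.
from ecgan.utils.miscellaneous import load_pickle
from ecgan.utils.splitting import load_split

split_indices = load_pickle('split.pkl')  # as logged via FileArtifact above
train_x, test_x, vali_x, train_y, test_y, vali_y = load_split(
    data, label, index_dict=split_indices, fold=1
)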
    @staticmethod
    def prepare_data(
        train_x,
        test_x,
        vali_x,
        train_y,
        test_y,
        vali_y,
        trainer_cfg: TrainerConfig,
        rescale_to_unit_circle: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Select channels, transform data, mask labels and possibly rescale data based on cfg.

        Args:
            train_x: Train data.
            test_x: Test data.
            vali_x: Validation data.
            train_y: Train labels.
            test_y: Test labels.
            vali_y: Validation labels.
            trainer_cfg: Trainer configuration.
            rescale_to_unit_circle: Flag indicating if the data shall be rescaled to the unit circle.

        Returns:
            Transformed train, test and vali data and labels.
        """
        ############################################################
        # TRANSFORM DATA
        ############################################################
        verbose_channel_selection(test_x, trainer_cfg.CHANNELS)
        train_x = select_channels(train_x, trainer_cfg.CHANNELS)
        vali_x = select_channels(vali_x, trainer_cfg.CHANNELS)
        test_x = select_channels(test_x, trainer_cfg.CHANNELS)

        if trainer_cfg.TRANSFORMATION != Transformation.NONE.value:
            logger.info("Applying data transformation: {0}.".format(trainer_cfg.TRANSFORMATION))
            transformer = get_transformation(trainer_cfg.transformation)
            # Fit the transformation on the training data only and apply it to all splits.
            train_x = transformer.fit_transform(train_x)
            test_x = transformer.transform(test_x)
            vali_x = transformer.transform(vali_x)

        if trainer_cfg.BINARY_LABELS:
            # Collapse all abnormal classes into a single positive class.
            train_y[train_y != 0] = 1
            vali_y[vali_y != 0] = 1
            test_y[test_y != 0] = 1

        if rescale_to_unit_circle:
            logger.info("Rescaling data to unit circle (range [-1, 1]).")
            train_x = scale_to_unit_circle(train_x)
            test_x = scale_to_unit_circle(test_x)
            vali_x = scale_to_unit_circle(vali_x)

        return train_x, test_x, vali_x, train_y, test_y, vali_y
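The ``BINARY_LABELS`` step above is plain tensor masking: every non-zero (i.e. abnormal) class is collapsed into class 1. A self-contained demonstration:

# Binary label masking as performed in prepare_data.
import torch

y = torch.tensor([0., 2., 1., 4., 0.])
y[y != 0] = 1
print(y)  # tensor([0., 1., 1., 1., 0.])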
    def _fit_split(self, train_x, vali_x, train_y, vali_y) -> Dict:
        """
        Fit training data using the model the trainer was initialized with.

        Args:
            train_x: Training data.
            vali_x: Validation data.
            train_y: Training labels.
            vali_y: Validation labels.

        Returns:
            Dictionary containing the highest metrics measured during training.
            The keys are metric names, the values are (epoch, value) tuples.
        """
        ############################################################
        # CREATE DATASETS & DATA LOADERS
        ############################################################
        train_data = SeriesDataset(train_x, train_y)
        validation_data = SeriesDataset(vali_x, vali_y)
        self.module.train_dataset_sampler.set_dataset(train_data)
        self.module.vali_dataset_sampler.set_dataset(validation_data)

        train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.trainer_cfg.BATCH_SIZE,
            shuffle=True,
            num_workers=self.trainer_cfg.NUM_WORKERS,
            pin_memory=True,
        )
        validation_loader = DataLoader(
            dataset=validation_data,
            shuffle=True,
            batch_size=self.trainer_cfg.BATCH_SIZE,
            num_workers=self.trainer_cfg.NUM_WORKERS,
            pin_memory=True,
        )

        ############################################################
        # CONDUCT TRAINING
        ############################################################
        logger.info('Start training with {0} epochs.'.format(self.trainer_cfg.EPOCHS))
        # The highest metrics are (epoch, value) pairs.
        highest_metrics: Dict = {}

        for epoch in range(1, self.trainer_cfg.EPOCHS + 1):
            train_metrics: List[Dict] = []
            validation_metrics: List[Dict] = []

            # TRAINING LOOP
            for _batch_idx, batch in enumerate(
                tqdm(train_loader, leave=False, desc='Training epoch {}'.format(epoch))
            ):
                metrics = self.module.training_step(batch)
                train_metrics.append(metrics)

            # VALIDATION LOOP
            for _batch_idx, batch in enumerate(
                tqdm(validation_loader, leave=False, desc='Validation epoch {}'.format(epoch))
            ):
                validation_results: Dict = self.module.validation_step(batch)
                # Skip missing or empty validation results.
                if validation_results:
                    validation_metrics.append(validation_results)

            # AFTER EPOCH ACTION
            artifacts = self.module.on_epoch_end(
                epoch, self.trainer_cfg.SAMPLE_INTERVAL, self.trainer_cfg.BATCH_SIZE
            )

            # HANDLE EPOCH ARTIFACTS AND TRACKING
            collated_train_metrics = self.tracker.collate_metrics(train_metrics)
            collated_validation_metrics = self.tracker.collate_metrics(validation_metrics)
            self.tracker.log_metrics(collated_train_metrics)
            self.tracker.log_metrics(collated_validation_metrics)
            self.tracker.log_artifacts(artifacts)
            self.tracker.advance_step()
            self.module.print_metric(epoch, collated_train_metrics)

            # CHECKPOINT
            highest_metrics = update_highest_metrics(
                collated_validation_metrics,
                artifacts,
                highest_metrics,
                epoch,
            )
            if save_epoch(
                highest_metrics,
                epoch,
                self.trainer_cfg.CHECKPOINT_INTERVAL,
                save_epoch_metrics,
                final_epoch=self.trainer_cfg.EPOCHS,
            ):
                fold = self.tracker.fold
                self.tracker.log_checkpoint(self.module.save_checkpoint(), fold)

        return highest_metrics

    def _fit_cross_val(self, split_indices: Dict):
        """Conduct training with cross-validation."""
        metric_tuple_list: List[dict] = []
        # GAN generators with a tanh output layer produce data in [-1, 1]:
        # rescale the real data accordingly.
        rescale_to_unit_circle = False
        if isinstance(self.module, BaseGANModule) and self.module.cfg.GENERATOR.TANH_OUT:
            rescale_to_unit_circle = True

        for fold in range(1, self.trainer_cfg.CROSS_VAL_FOLDS + 1):
            train_x, test_x, vali_x, train_y, test_y, vali_y = load_split(
                self.data, self.label, index_dict=split_indices, fold=fold
            )
            train_x, _test_x, vali_x, train_y, _test_y, vali_y = self.prepare_data(
                train_x,
                test_x,
                vali_x,
                train_y,
                test_y,
                vali_y,
                self.trainer_cfg,
                rescale_to_unit_circle,
            )
            metrics = self._fit_split(train_x, vali_x, train_y, vali_y)
            metric_tuple_list.append(metrics)

            # Advance fold and create a new module.
            if fold < self.trainer_cfg.CROSS_VAL_FOLDS:
                self.tracker.advance_fold()
                module_factory = ModuleFactory()
                self.module = module_factory(
                    cfg=self.module_cfg,
                    module=self.exp_cfg.MODULE,
                    seq_len=vali_x.shape[1],
                    num_channels=self.num_channels,
                )

        metric_dict_list = list_from_tuple_list(metric_tuple_list)
        self.tracker.log_artifacts(FileArtifact('metrics across all folds', metric_dict_list, 'metrics.json'))
        metric_list, name_list = nested_list_from_dict_list(metric_dict_list)
        collated = self.tracker.collate_metrics(metric_dict_list)

        # Create a boxplot for all metrics.
        self.tracker.log_artifacts(
            ImageArtifact(
                'Metric Boxplot Train',
                boxplot(
                    metric_list,
                    name_list,
                    "Metrics across {0} folds.".format(self.trainer_cfg.CROSS_VAL_FOLDS),
                ),
            )
        )
        logger.info("Collated metrics across folds: {0}.".format(collated))
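Since :meth:`prepare_data` is a public staticmethod, it can in principle be reused outside the cross-validation loop, e.g. to preprocess a held-out split with the same configuration. A hypothetical sketch, assuming ``trainer_cfg`` is an already constructed :class:`TrainerConfig` and the six arrays an already loaded split:

# Hypothetical standalone use of Trainer.prepare_data; all inputs are assumed
# to exist (e.g. produced by load_split as shown above).
train_x, test_x, vali_x, train_y, test_y, vali_y = Trainer.prepare_data(
    train_x, test_x, vali_x, train_y, test_y, vali_y,
    trainer_cfg=trainer_cfg,
    rescale_to_unit_circle=False,
)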