# Source code for ecgan.modules.base

"""Abstract base module for learning algorithms which use training and validation steps."""
from __future__ import annotations

from abc import abstractmethod
from logging import getLogger
from typing import List

from ecgan.config import ModuleConfig, get_global_config
from ecgan.utils.artifacts import Artifact
from ecgan.utils.configurable import Configurable
from ecgan.utils.custom_types import SampleDataset
from ecgan.utils.datasets import DatasetFactory
from ecgan.utils.miscellaneous import select_device
from ecgan.utils.sampler import DataSampler
from ecgan.visualization.plotter import PlotterFactory

logger = getLogger(__name__)


class BaseModule(Configurable):
    """
    Base class from which all implemented modules should inherit.

    On construction a module resolves the dataset and plotter from the global
    configuration, selects the compute device and creates samplers for the
    training and validation data. Subclasses implement the abstract hooks for
    training/validation steps, checkpointing, loading, the list of watched
    modules and the end-of-epoch logic.
    """

    def __init__(
        self,
        cfg: ModuleConfig,
        seq_len: int,
        num_channels: int,
    ):
        """
        Initialize state shared by all modules.

        Args:
            cfg: Module configuration, stored as ``self.cfg``.
            seq_len: Sequence length, passed on to the data samplers.
            num_channels: Number of channels, passed on to the data samplers.
        """
        self.cfg = cfg

        # Fetch the global configuration once and read both sub-configs from it.
        global_cfg = get_global_config()
        exp_cfg = global_cfg.experiment_config
        trainer_cfg = global_cfg.trainer_config

        self.dataset = DatasetFactory()(exp_cfg.DATASET)
        self.plotter = PlotterFactory.from_config(trainer_cfg)
        self.seq_len = seq_len
        self.num_channels = num_channels
        self.device = select_device(gpu_flag=exp_cfg.TRAIN_ON_GPU)
        # Lazy %-style args: the message is only rendered if INFO is enabled.
        logger.info('Using device %s.', self.device)

        # Samplers are created with `None` as their first argument.
        # NOTE(review): presumably the actual data is attached later by the
        # caller/trainer -- confirm against DataSampler usage.
        self.train_dataset_sampler: DataSampler = DataSampler(
            None, self.device, num_channels, seq_len, name=SampleDataset.TRAIN.value
        )
        self.vali_dataset_sampler: DataSampler = DataSampler(
            None, self.device, num_channels, seq_len, name=SampleDataset.VALI.value
        )

    @abstractmethod
    def training_step(
        self,
        batch: dict,
    ) -> dict:
        """
        Declare what the model should do during a training step using a given batch.

        The returned metrics are concatenated across batches and **averaged** before logging.
        Note: this is important if you want to log min/max values!

        Args:
            batch: A batch of data.

        Return:
            A dict containing the metrics from optimization or evaluation which shall be logged.
        """
        raise NotImplementedError("BaseModule needs to implement the `training_step` method.")

    @abstractmethod
    def validation_step(
        self,
        batch: dict,
    ) -> dict:
        """
        Declare what the model should do during a validation step.

        Args:
            batch: A batch of data.

        Return:
            A dict containing the metrics from optimization or evaluation which shall be logged.
        """
        raise NotImplementedError("BaseModule needs to implement the `validation_step` method.")

    @abstractmethod
    def save_checkpoint(self) -> dict:
        """Return current model parameters."""
        raise NotImplementedError("BaseModule needs to implement the `save_checkpoint` method.")

    @abstractmethod
    def load(self, model_reference: str, load_optim: bool = False) -> BaseModule:
        """
        Load a trained module from disk (file path) or a wandb reference.

        Args:
            model_reference: File path or remote reference of the stored module.
            load_optim: Whether optimizer state should be loaded as well.

        Return:
            The loaded module.
        """
        raise NotImplementedError("BaseModule needs to implement the `load` method.")

    @property
    @abstractmethod
    def watch_list(self) -> List:
        """
        Return torch nn.Modules that should be watched during training.

        Unlike the other abstract hooks, the base implementation does not raise;
        it returns ``None``.
        """
        pass

    @staticmethod
    def _print_metric(epoch: int, metrics: dict) -> None:
        """
        Log formatted metrics that were produced by the model.

        The resulting line has the form ``'Ep. <epoch> -> k1: v1 | k2: v2 | '``,
        with every value rendered with four decimal places (so each value must
        support the ``.4f`` format spec).

        Args:
            epoch: Current epoch number, shown as the line prefix.
            metrics: Mapping from metric name to numeric value.
        """
        # A single join instead of repeated string concatenation in a loop.
        metric_parts = ''.join('{0}: {1:.4f} | '.format(key, value) for key, value in metrics.items())
        formatted_metrics: str = 'Ep. {0} -> '.format(epoch) + metric_parts
        logger.info(formatted_metrics)

    @classmethod
    def print_metric(cls, epoch: int, metrics: dict) -> None:
        """
        Allow overwriting the static `_print_metric` method.

        Dispatches through ``cls`` so subclasses overriding ``_print_metric``
        change the output format.

        Args:
            epoch: Current epoch number.
            metrics: Mapping from metric name to value.
        """
        cls._print_metric(epoch, metrics)

    @abstractmethod
    def on_epoch_end(self, epoch: int, sample_interval: int, batch_size: int) -> List[Artifact]:
        """
        Set actions to be executed after epoch ends.

        Declare what should be done upon finishing an epoch (e.g. save artifacts or
        evaluate some metric).

        Args:
            epoch: Current training epoch.
            sample_interval: Regular sampling interval to save modules independent of their performance.
            batch_size: Size of batch.

        Returns:
            List containing all :class:`ecgan.utils.artifacts.Artifact` s which shall be logged upon epoch end.
        """
        raise NotImplementedError("Module needs to implement the `on_epoch_end` method.")