# Source code for ecgan.modules.inverse_mapping.inverse_mapping

r"""
Defines a module that includes an inverse mapping.

In the typical ECGAN use-case the module would consist of two different mappings: :math:`G: A \rightarrow B` and
:math:`Inv: B \rightarrow A`, where :math:`G` is a typical generator that maps a given distribution (commonly
a normal distribution) :math:`A` to some set :math:`B`. The :math:`Inv` function is the inverse mapping, tasked with
mapping a given sample of :math:`B` back to :math:`A`, i.e. (approximately) restoring the generator input that produced it.
"""
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict, Optional

import torch
from torch import nn
from tqdm import tqdm

from ecgan.config import InverseConfig, ModuleConfig, get_global_config, get_global_inv_config
from ecgan.evaluation.tracker import BaseTracker, TrackerFactory
from ecgan.modules.base import BaseModule
from ecgan.utils.configurable import Configurable
from ecgan.utils.miscellaneous import load_model


class InvertibleBaseModule(BaseModule, Configurable):
    """
    The abstract base class for inverse mappings.

    Every implementation of this class

    #. gets at least a reference to some trained generator module and
    #. must implement an inverse method (:code:`invert`) that restores the input data for the generator module.
    """

    def __init__(
        self,
        inv_cfg: InverseConfig.Attribs,
        module_cfg: ModuleConfig,
        seq_len: int,
        num_channels: int,
        tracker: Optional[BaseTracker] = None,
    ):
        """
        Set up the generator module, the DataParallel-wrapped inverse mapping and the tracker.

        Args:
            inv_cfg: Attributes of the inverse mapping configuration (read later by :meth:`train`:
                EPOCHS, ROUNDS, BATCH_SIZE, SAVE_CHECKPOINT, ARTIFACT_CHECKPOINT).
            module_cfg: Module configuration, forwarded to the base module.
            seq_len: Sequence length, forwarded to the base module.
            num_channels: Number of channels, forwarded to the base module.
            tracker: Optional externally owned tracker. If ``None``, a tracker is created from the global
                experiment config and this module closes it when :meth:`train` ends.
        """
        super().__init__(module_cfg, seq_len, num_channels)
        self._inv_config = inv_cfg
        # Subclasses build the generator and the raw inverse network; the latter is wrapped in DataParallel.
        self._generator_module = self._init_generator_module()
        self.inv = self._init_inv()
        self.inv = nn.DataParallel(self.inv)
        self.inv.to(self.device)
        self.exp_cfg = get_global_config().experiment_config

        # Only trackers created here are closed by this module; an injected tracker is owned by the caller.
        if tracker is None:
            self.tracker = TrackerFactory()(config=self.exp_cfg)
            self.close_tracker = True
        else:
            self.tracker = tracker
            self.close_tracker = False

        self.tracker.log_config(get_global_inv_config().config_dict)

    @abstractmethod
    def invert(self, data) -> torch.Tensor:
        """
        Apply the inverse mapping for the provided data.

        Note that the function does not make any assumptions about the gradient and must be wrapped into a
        torch.no_grad if the gradient is not needed.
        """
        raise NotImplementedError("InvertibleModule needs to implement the `invert` method.")

    def training_step(
        self,
        batch: dict,
    ) -> dict:
        """
        Perform a training step that can be called by a :class:`ecgan.training.trainer.Trainer` class instance.

        Note that the function does not require training data and the batch is only required to figure out the
        appropriate batch size.

        Args:
            batch: Batch dict; only ``batch['data'].shape[0]`` (the batch size) is used.

        Returns:
            The metrics dict returned by :meth:`_training_step`.
        """
        return self._training_step(batch['data'].shape[0])

    @abstractmethod
    def _training_step(
        self,
        batch_size: int,
    ) -> dict:
        """
        Perform a training step.

        This private method is more appropriate to the definition of the inverse mapping. The batch_size is
        required to sample a matching number of noise vectors.

        Args:
            batch_size: Number of samples to draw for this step.

        Returns:
            A dict of metrics for this step.
        """
        raise NotImplementedError("InvertibleModule needs to implement the `_training_step` method.")

    @property
    def inv_cfg(self) -> InverseConfig.Attribs:
        """The inverse mapping configuration passed to the constructor."""
        return self._inv_config

    @property
    def generator_module(self) -> BaseModule:
        """The (trained) generator module whose mapping is inverted."""
        return self._generator_module

    @generator_module.setter
    def generator_module(self, module: BaseModule) -> None:
        self._generator_module = module

    @property
    def inv(self) -> Any:
        """The inverse mapping network (wrapped in ``nn.DataParallel`` by ``__init__``)."""
        return self._inv

    @inv.setter
    def inv(self, value: Any) -> None:
        self._inv = value

    @abstractmethod
    def _init_generator_module(self) -> BaseModule:
        """Create the generator module; called once from ``__init__``."""
        raise NotImplementedError("InvertibleModule needs to implement the `_init_generator_module` method.")

    @abstractmethod
    def _init_inv(self) -> Any:
        """Create the raw inverse network; ``__init__`` wraps the result in ``nn.DataParallel``."""
        raise NotImplementedError("InvertibleModule needs to implement the `_init_inv` method.")

    def load(self, model_reference: str, load_optim: bool = False) -> InvertibleBaseModule:
        """
        Load an inverse mapping model.

        The modules have to decide to save/load optimizers by themselves. The generator is loaded from
        ``inv_cfg.RUN_URI``, not from the checkpoint referenced by ``model_reference``.

        Args:
            model_reference: Reference used to load an existing model.
            load_optim: Flag to indicate if the optimizer params should be loaded.

        Returns:
            ``self``, to allow chaining.
        """
        model = load_model(model_reference, self.device)
        self._load_inv(model['INV'], load_optim)
        # NOTE(review): save_checkpoint() also stores the generator under 'G', but it is not used here;
        # confirm that reloading the generator from RUN_URI is intended.
        self._load_generator_module(self._inv_config.RUN_URI)
        return self

    def save_checkpoint(self) -> dict:
        """
        Save a checkpoint of the inverse mapping.

        Returns:
            A dict with the generator checkpoint under ``'G'`` and the inverse mapping state under ``'INV'``.
        """
        return {'G': self.generator_module.save_checkpoint(), 'INV': self._save_inv()}

    @abstractmethod
    def _load_inv(self, inv_dict: Dict, load_optim: bool = False) -> None:
        """Load inversion module to memory."""
        raise NotImplementedError("InvertibleModule needs to implement the `_load_inv` method.")

    @abstractmethod
    def _load_generator_module(self, model_reference: Any) -> None:
        """Load generator module to memory."""
        raise NotImplementedError("InvertibleModule needs to implement the `_load_generator_module` method.")

    @abstractmethod
    def _save_inv(self) -> Dict:
        """Return the serializable state of the inverse mapping (stored under ``'INV'`` in checkpoints)."""
        raise NotImplementedError("InvertibleModule needs to implement the `_save_inv` method.")

    def train(self) -> None:
        """
        Train an inverse mapping.

        For each of ``EPOCHS`` epochs, :meth:`_training_step` is called ``ROUNDS`` times with ``BATCH_SIZE``.
        The epoch hook (``on_epoch_end``) then runs and the collated metrics and artifacts are logged.
        A checkpoint is logged every ``SAVE_CHECKPOINT`` epochs; a falsy value (e.g. 0) disables checkpointing.
        If this module created its own tracker, the tracker is closed when training ends, including on error.
        """
        epochs = self.inv_cfg.EPOCHS
        rounds = self.inv_cfg.ROUNDS
        batch_size = self.inv_cfg.BATCH_SIZE
        checkpoint_interval = self.inv_cfg.SAVE_CHECKPOINT
        artifact_checkpoint = self.inv_cfg.ARTIFACT_CHECKPOINT

        try:
            for epoch in tqdm(range(1, epochs + 1)):
                # TRAINING LOOP
                train_metrics = [self._training_step(batch_size) for _ in range(rounds)]

                # AFTER EPOCH ACTION (on_epoch_end is not defined in this class; presumably inherited)
                artifacts = self.on_epoch_end(epoch, artifact_checkpoint, batch_size)

                # CHECKPOINT -- guard avoids ZeroDivisionError when checkpointing is disabled with 0
                if checkpoint_interval and epoch % checkpoint_interval == 0:
                    self.tracker.log_checkpoint(self.save_checkpoint(), self.tracker.fold)

                # HANDLE EPOCH ARTIFACTS AND TRACKING
                collated_train_metrics = self.tracker.collate_metrics(train_metrics)
                self.tracker.log_metrics(collated_train_metrics)
                self.tracker.log_artifacts(artifacts)
                self.tracker.advance_step()
                self.print_metric(epoch, collated_train_metrics)
        finally:
            # Release a self-created tracker even if training raised.
            if self.close_tracker:
                self.tracker.close()