Source code for ecgan.evaluation.tracker.base_tracker

"""Base class for trackers."""
import os
from abc import ABC, abstractmethod
from statistics import mean
from typing import Dict, List, Union

import torch
import yaml

from ecgan.utils.artifacts import Artifact, FileArtifact
from ecgan.utils.miscellaneous import save_pickle
from ecgan.visualization.plotter import BasePlotter


class BaseTracker(ABC):
    """Base class for trackers."""

    def __init__(self, entity: str, project: str, run_name: str, save_pdf: bool = False):
        """Set basic tracking parameters."""
        self.entity = entity
        self.project = project
        self.run_name = run_name
        self.step = 1
        self.plotter = BasePlotter()
        self.fold = 1
        self.run_dir = self._init_run_dir()
        os.makedirs(self.run_dir, exist_ok=True)
        self.save_pdf = save_pdf
    def advance_step(self):
        """Advance by one step."""
        self.step += 1
    def advance_fold(self):
        """Advance by one fold."""
        self.fold += 1
    @abstractmethod
    def _init_run_dir(self) -> str:
        """Initialize the run directory and return its path."""
        raise NotImplementedError("Tracker needs to implement the `_init_run_dir` method.")
    @abstractmethod
    def close(self):
        """Close training run."""
        raise NotImplementedError("Tracker needs to implement the `close` method.")
    @staticmethod
    def collate_metrics(metrics: List[Dict]) -> Dict:
        """
        Transform a list of dictionaries into a single dictionary.

        All values are collated under their respective keys and then averaged.

        Args:
            metrics: A list of metric dictionaries.

        Returns:
            Dictionary with the collated metrics.
        """
        metrics_collated = {}
        key_list: List = []
        for metric_dict in metrics:
            key_list.extend(metric_dict.keys())
        key_list = list(set(key_list))
        for key in key_list:
            metric_list = [metric.get(key) for metric in metrics if metric.get(key) is not None]
            metrics_collated.update({key: mean(metric_list)})  # type: ignore
        return metrics_collated
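    # Illustrative example (not part of the original source): collating per-fold
    # metrics before logging them. Keys missing from some dictionaries are averaged
    # only over the entries that report them; the key order of the result may vary
    # because keys are deduplicated via a set.
    #
    #   >>> BaseTracker.collate_metrics([{'loss': 0.5, 'acc': 1.0}, {'loss': 0.25}])
    #   {'loss': 0.375, 'acc': 1.0}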
    @abstractmethod
    def watch(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        """Pass a module that should be watched during training."""
        pass
    @abstractmethod
    def log_config(self, cfg: Dict) -> None:
        """Log parameters for the training setup."""
        raise NotImplementedError("Tracker needs to implement the `log_config` method.")
    @abstractmethod
    def log_metrics(self, metrics: Dict) -> None:
        """
        Take a dictionary with metrics and log its content.

        Args:
            metrics: Dict of the form {metric: val, ...}.
        """
        raise NotImplementedError("Tracker needs to implement the `log_metrics` method.")
    @abstractmethod
    def log_checkpoint(self, module_checkpoint: Dict, fold: int) -> None:
        """
        Save a module checkpoint.

        The checkpoint is a dictionary containing the module's state dict and optimizer parameters.

        Args:
            module_checkpoint: Dictionary with model weights.
            fold: Current fold. Should be extracted before the upload begins to avoid concurrency issues.
        """
        raise NotImplementedError("Tracker needs to implement the `log_checkpoint` method.")
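    # Illustrative sketch (assumed structure, not taken from the original source):
    # a concrete tracker would typically be handed a checkpoint dictionary along
    # the lines of
    #
    #   checkpoint = {
    #       'module': model.state_dict(),
    #       'optimizer': optimizer.state_dict(),
    #   }
    #   tracker.log_checkpoint(checkpoint, fold=tracker.fold)
    #
    # The exact keys are implementation specific; the base class only requires a Dict.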
    @abstractmethod
    def log_artifacts(self, artifacts: Union[Artifact, List[Artifact]]) -> None:
        """
        Log one or more artifacts.

        Args:
            artifacts: A single artifact or a list of artifacts to log.
        """
        raise NotImplementedError("Tracker needs to implement the `log_artifacts` method.")
    @abstractmethod
    def load_config(self, run_uri: str) -> Dict:
        """
        Load config.

        Args:
            run_uri: Path pointing to the project root.
        """
        raise NotImplementedError("Tracker needs to implement the `load_config` method.")
    def _local_file_save(self, artifact: FileArtifact) -> str:
        """
        Save data to the local file system.

        Args:
            artifact: Artifact containing the file.

        Returns:
            The local path the file is saved to.
        """
        file_name: str = artifact.file_name
        save_path = os.path.join(self.run_dir, file_name)

        if file_name.endswith('.pdf'):
            self.plotter.save_plot(artifact.data, save_path)
        elif file_name.endswith('.yml'):
            with open(save_path, 'w', encoding='utf-8') as out_file:
                yaml.dump(artifact.data, out_file)
        elif file_name.endswith('.pkl'):
            save_pickle(artifact.data, self.run_dir, file_name)
        else:
            with open(save_path, 'w', encoding='utf-8') as out_file:
                out_file.write(str(artifact.data))

        return save_path
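

# ---------------------------------------------------------------------------
# Illustrative sketch only -- this subclass is NOT part of ecgan. It shows, under
# assumed names (`LocalTracker`, a `runs/` base directory, `config.yml`,
# `metrics.log`), how the abstract methods above could be implemented for purely
# local logging.
class LocalTracker(BaseTracker):
    """Minimal tracker that logs everything to the local file system (illustrative)."""

    def _init_run_dir(self) -> str:
        # Assumed layout: runs/<project>/<run_name>.
        return os.path.join('runs', self.project, self.run_name)

    def close(self) -> None:
        # Nothing to clean up for purely local logging.
        pass

    def watch(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        # Local logging does not inspect weights or gradients.
        pass

    def log_config(self, cfg: Dict) -> None:
        with open(os.path.join(self.run_dir, 'config.yml'), 'w', encoding='utf-8') as out_file:
            yaml.dump(cfg, out_file)

    def log_metrics(self, metrics: Dict) -> None:
        with open(os.path.join(self.run_dir, 'metrics.log'), 'a', encoding='utf-8') as out_file:
            out_file.write('step {0}: {1}\n'.format(self.step, metrics))

    def log_checkpoint(self, module_checkpoint: Dict, fold: int) -> None:
        torch.save(module_checkpoint, os.path.join(self.run_dir, 'checkpoint_fold_{0}.pt'.format(fold)))

    def log_artifacts(self, artifacts: Union[Artifact, List[Artifact]]) -> None:
        artifact_list = artifacts if isinstance(artifacts, list) else [artifacts]
        for artifact in artifact_list:
            if isinstance(artifact, FileArtifact):
                self._local_file_save(artifact)

    def load_config(self, run_uri: str) -> Dict:
        with open(os.path.join(run_uri, 'config.yml'), 'r', encoding='utf-8') as in_file:
            return yaml.safe_load(in_file)


# Example usage of the sketch above:
#   tracker = LocalTracker(entity='me', project='demo', run_name='run-1')
#   tracker.log_metrics({'loss': 0.42})
#   tracker.advance_step()
#   tracker.close()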