Source code for ecgan.evaluation.tracker.wb_tracker

"""Tracker storing data using weights and biases."""
import os
import threading
import time
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import wandb

from ecgan.evaluation.tracker.base_tracker import BaseTracker
from ecgan.utils.artifacts import Artifact, FileArtifact, ImageArtifact, ValueArtifact
from ecgan.utils.miscellaneous import load_wandb_config

logger = getLogger(__name__)


class WBTracker(BaseTracker):
    """
    Class to manage calculating metrics and logging them with Weights & Biases (W&B).

    Args:
        entity: Weights and Biases entity name (given by W&B).
        project: Weights and Biases project name (chosen by user).
        run_name: Weights and Biases run name (chosen by user).
        save_locally: Flag to retain local storage of data.
        save_pdf: Flag to save data as PDF **in addition** to the normal plots.
        save_checkpoint_s3: Flag to upload model checkpoints to an S3 bucket named after the entity.
    """

    PDF_DIR = "pdfs"

    def __init__(
        self,
        entity: str,
        project: str,
        run_name: str,
        save_locally: bool = False,
        save_pdf: bool = False,
        save_checkpoint_s3: bool = False,
    ):
        """Set basic tracking parameters."""
        self.step = 1
        self.run = wandb.init(project=project, name=run_name, entity=entity, reinit=True)
        self.run_id = wandb.run.id  # type: ignore
        # Data will be saved locally to self.temp_dir during training and removed
        # if the `save_locally` flag is not explicitly set to True.
        self.temp_dir: Optional[TemporaryDirectory] = None
        self.save_locally = save_locally
        self.save_pdf = save_pdf
        self.thread_list: List[threading.Thread] = []
        super().__init__(entity, project, run_name, save_pdf)

        if save_checkpoint_s3:
            import boto3  # pylint: disable=C0415

            self.s3_bucket = boto3.resource('s3').Bucket(entity)
        self.save_checkpoint_s3 = save_checkpoint_s3

    def _init_run_dir(self) -> str:
        """
        Initialize the run directory.

        Returns:
            The file path.
        """
        if self.save_locally:
            run_dir = os.path.join('results', 'runs', self.run_id)
        else:
            self.temp_dir = TemporaryDirectory()  # pylint: disable=R1732
            run_dir = self.temp_dir.name

        if self.save_pdf:
            os.makedirs("{}/{}".format(run_dir, WBTracker.PDF_DIR))

        return run_dir
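    # Usage sketch (illustrative, not part of the original module). Constructing
    # the tracker starts a W&B run immediately via ``wandb.init``; the entity,
    # project, and run names below are hypothetical placeholders.
    #
    #     tracker = WBTracker(
    #         entity='my-team',            # hypothetical W&B entity
    #         project='ecg-experiments',   # hypothetical project name
    #         run_name='baseline-run',
    #         save_locally=True,           # keep results/runs/<run_id> after training
    #     )
    #     tracker.log_config({'learning_rate': 1e-4, 'epochs': 50})
    #     tracker.log_metrics({'train/loss': 0.42})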
    def close(self):
        """
        Close the training run.

        Uploading might take some time. The run is interrupted if the upload has
        not completed after 90 seconds; you can manually upload the data
        afterwards with a local sync.
        """
        timeout = 90
        start = time.time()
        still_alive = True
        remaining_time = timeout

        while still_alive:
            still_alive = False
            for thread in self.thread_list:
                thread.join(timeout=0.001)
                if thread.is_alive():
                    still_alive = True
            print(
                'Waiting for threaded upload to finish. Force quit in {0:3} second(s) if '
                'upload does not complete.'.format(remaining_time),
                end='\r',
            )
            if time.time() - start > timeout:
                break
            time.sleep(1)
            remaining_time -= 1

        self.run.finish()

        if not self.save_locally and self.temp_dir is not None:
            self.temp_dir.cleanup()
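    # If ``close`` hits the 90 second timeout, pending data remains in the local
    # run directory and can be pushed later with W&B's offline sync CLI, e.g.
    # ``wandb sync <local run directory>`` (the exact path depends on your setup).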
    def watch(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        """Pass a module that should be watched during training."""
        if isinstance(model, list):
            for mod in model:
                wandb.watch(mod, log_freq=1)
        else:
            wandb.watch(model, log_freq=1)
    def log_config(self, cfg: Dict) -> None:
        """Log parameters for the training setup."""
        wandb.config.update(cfg)  # pylint: disable=E1101
    def log_metrics(self, metrics: Dict) -> None:
        """
        Take a dictionary with metrics and log its content.

        Args:
            metrics: Dict in shape {metric: val, ...}.
        """
        wandb.log(metrics, step=self.step)
    def log_checkpoint(self, module_checkpoint: Dict, fold: int) -> None:
        """
        Save a module checkpoint.

        The checkpoint is a dictionary containing the module's state dict and
        optimizer parameters.

        Args:
            module_checkpoint: Dictionary with model weights.
            fold: Current fold. Should be extracted before the upload begins to
                avoid concurrency issues.
        """
        self._start_thread(
            target=self._upload_checkpoint,
            args=(
                module_checkpoint,
                fold,
            ),
        )
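    # Checkpoint sketch (hypothetical key names; the tracker only requires a
    # dict that ``torch.save`` can serialize):
    #
    #     checkpoint = {
    #         'model': module.state_dict(),          # model weights
    #         'optimizer': optimizer.state_dict(),   # optimizer parameters
    #     }
    #     tracker.log_checkpoint(checkpoint, fold=0)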
    def _start_thread(self, target: Any, args: Iterable) -> None:
        """Start a thread running ``target`` and track it for joining on close."""
        thread = threading.Thread(target=target, args=args)
        self.thread_list.append(thread)
        thread.start()

    @staticmethod
    def _upload_file(path: str) -> None:
        """Upload a file to W&B."""
        wandb.save(path)

    def _upload_checkpoint(self, module_checkpoint: Dict, fold: int) -> None:
        """Upload the checkpoint to S3 and log a reference artifact, or attach the file directly to a W&B artifact."""
        path = os.path.join(self.run_dir, 'model_ep_{}_fold_{}.pt'.format(self.step, fold))
        torch.save(module_checkpoint, path)
        model_id = str(self.run_id) + '_fold' + str(fold)
        art = wandb.Artifact(model_id, type='model')

        if self.save_checkpoint_s3:
            s3_path = path.replace(self.run_dir, self.run_id)
            self.s3_bucket.upload_file(path, s3_path)
            s3_reference = 's3://' + self.entity + '/' + s3_path
            art.add_reference(s3_reference)
        else:
            art.add_file(path)
        wandb.log_artifact(art)

        if self.save_checkpoint_s3:
            logger.info('Uploaded checkpoint to S3.')
        else:
            logger.info('Uploaded checkpoint to W&B.')
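    # With ``save_checkpoint_s3`` enabled, a checkpoint written locally to
    # ``<run_dir>/model_ep_5_fold_0.pt`` is uploaded under the key
    # ``<run_id>/model_ep_5_fold_0.pt`` in the bucket named after the entity,
    # and W&B stores only the ``s3://<entity>/<run_id>/...`` reference rather
    # than the file itself.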
    def log_artifacts(self, artifacts: Union[Artifact, List[Artifact]]) -> None:
        """
        Log artifacts to W&B.

        Args:
            artifacts: A single artifact or a list of artifacts to log.
        """
        metrics_log: Dict = {}
        artifacts = artifacts if isinstance(artifacts, List) else [artifacts]

        for artifact in artifacts:
            if isinstance(artifact, ImageArtifact):
                metrics_log.update({artifact.name: wandb.Image(artifact.image)})
                if self.save_pdf and artifact.figure is not None:
                    file_name = artifact.name.replace(" ", "_")
                    pdf_artifact = FileArtifact(
                        artifact.name, artifact.figure, "{}_{}.pdf".format(file_name, self.step)
                    )
                    path = self._local_file_save(pdf_artifact)
                    self._start_thread(target=self._upload_file, args=("{}".format(path),))
            elif isinstance(artifact, ValueArtifact):
                metrics_log.update({artifact.name: artifact.value})
            elif isinstance(artifact, FileArtifact):
                path = self._local_file_save(artifact)
                self._start_thread(target=self._upload_file, args=(path,))
                logger.info('Saved {0} in {1}'.format(artifact.name, path))
            else:
                logger.warning("Artifact type was not found: {0}.".format(type(artifact)))

        if metrics_log:
            wandb.log(metrics_log, step=self.step)
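    # Artifact sketch (illustrative; the exact constructor signatures live in
    # ecgan.utils.artifacts and are assumed here):
    #
    #     tracker.log_artifacts(ValueArtifact('val/fscore', 0.87))
    #
    # Scalar values are merged into the metrics log of the current step, while
    # images and files are uploaded in background threads.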
    def load_config(self, run_uri: str) -> Dict:
        """Load the config of a run as a dict."""
        return load_wandb_config(run_uri)
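    # Run URI sketch: W&B run paths take the form ``entity/project/run_id``, so
    # a call like the following is expected (names hypothetical):
    #
    #     cfg = tracker.load_config('my-team/ecg-experiments/1a2b3c4d')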