Source code for ecgan.evaluation.tracker.tracker_factory

"""Tracker factory to retrieve the instance of a tracker defined in the configuration file."""
from typing import Union

from ecgan.config import AdExperimentConfig, ExperimentConfig
from ecgan.evaluation.tracker.base_tracker import BaseTracker
from ecgan.evaluation.tracker.local_tracker import LocalTracker
from ecgan.evaluation.tracker.wb_tracker import WBTracker
from ecgan.utils.custom_types import TrackerType


[docs]class TrackerFactory: """Create desired tracker.""" def __call__(self, config: Union[ExperimentConfig, AdExperimentConfig]) -> BaseTracker: """Return a BaseTracker instance.""" tracker = config.TRACKER entity = tracker.ENTITY project = tracker.PROJECT run_name = tracker.EXPERIMENT_NAME save_pdf = tracker.SAVE_PDF if tracker.tracker_name == TrackerType.LOCAL: return LocalTracker(entity, project, run_name, save_pdf=save_pdf) if tracker.tracker_name == TrackerType.WEIGHTS_AND_BIASES: save_locally = tracker.LOCAL_SAVE upload_checkpoint_to_s3 = tracker.S3_CHECKPOINT_UPLOAD return WBTracker( entity, project, run_name, save_locally=save_locally, save_pdf=save_pdf, save_checkpoint_s3=upload_checkpoint_to_s3, ) raise AttributeError('Argument {0} is not set correctly.'.format(tracker.tracker_name.value))