Source code for ecgan.utils.artifacts

"""Supported artifact types which are used as templates for tracking experiments."""
from typing import Any, Dict, Union

import matplotlib.pyplot as plt
import numpy as np

from ecgan.utils.plotting import matplotlib_render_to_array


class Artifact:
    """
    Base class for all artifacts supported by the tracker.

    Meant to be subclassed; the base itself only records the artifact's name
    and is not prevented from being instantiated directly.
    """

    def __init__(self, name: str) -> None:
        """
        Initialize the artifact.

        Args:
            name: Name of the artifact.
        """
        self.name = name
class ImageArtifact(Artifact):
    """
    Artifact wrapper for an image encoded as np.ndarray or a mpl Figure.

    If a matplotlib figure is passed, it is rendered to an array via
    ``matplotlib_render_to_array`` and the figure itself is kept so that it
    can be closed once the artifact is destroyed.
    """

    def __init__(self, name: str, image: Union[np.ndarray, plt.Figure]):
        """
        Initialize an image artifact.

        Args:
            name: Name of the artifact.
            image: Either the image data itself or a matplotlib figure that
                is rendered into an array.
        """
        super().__init__(name)
        # Only set to a figure if one was passed in; arrays own no figure.
        self.figure = None
        if isinstance(image, plt.Figure):
            self.image = matplotlib_render_to_array(image)
            self.figure = image
        else:
            self.image = image

    def __del__(self):
        """Close the internal figure when the object is destroyed."""
        # The finalizer also runs for objects whose __init__ never completed
        # (or that were created without __init__), so the attribute may be
        # missing; look it up defensively instead of raising AttributeError.
        figure = getattr(self, 'figure', None)
        # During interpreter shutdown module globals may already be torn
        # down, in which case there is nothing left to close the figure with.
        if figure is not None and plt is not None:
            plt.close(figure)
class ValueArtifact(Artifact):
    """Artifact which holds a single value (e.g. a metric)."""

    def __init__(self, name: str, value: Union[float, Dict]) -> None:
        """
        Initialize a value artifact.

        Args:
            name: Name of the artifact, forwarded to the base class.
            value: The value to keep, either a float or a dictionary.
        """
        super().__init__(name)
        self.value = value
class FileArtifact(Artifact):
    """Artifact bundling data with the yaml or pickle file it shall be saved into."""

    def __init__(self, name: str, data: Any, file_name: str) -> None:
        """
        Initialize a file artifact.

        Args:
            name: Name of the saved data. It is ONLY USED FOR LOGGING and
                does not need to be a file name.
            data: The data that shall be saved.
            file_name: Location to save to. Expected to point to a .yml or
                .pkl file; this is not validated here.
        """
        super().__init__(name)
        self.file_name = file_name
        self.data = data