"""Describes different plotting classes."""
import math
from typing import List, Optional, Tuple, Union
import matplotlib.colors as mplcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Axes, Figure
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from scipy import fftpack
from torch import Tensor
from ecgan.config import TrainerConfig
from ecgan.utils.custom_types import PlotterType, Transformation
from ecgan.utils.miscellaneous import to_numpy, to_torch
[docs]def matplotlib_prep(size: Tuple[int, int], subplots: int = 1, y_lim: Optional[Tuple] = None) -> Tuple[Figure, Axes]:
"""
Return a figure and an axis for a matplotlib plot with a given size.
Args:
size: Target (width, height) of the plot.
subplots: amount of subplots of the figure.
y_lim: Visual y limits of the plot.
Returns:
The processed figure and axes.
"""
# Standard dpi in matplotlib is 100.
dpi = 100
fig, ax = plt.subplots(subplots, figsize=(size[0] / dpi, size[1] / dpi))
plt.tight_layout()
if y_lim is not None:
plt.ylim(y_lim[0], y_lim[1])
return fig, ax
[docs]class BasePlotter:
"""Base class for plotting classes. Creates its plots by simple call to plt.plot."""
@staticmethod
def _get_data_axes(data: Union[Tensor, np.ndarray]) -> np.ndarray:
"""
Get the x and y axis corresponding to a given time series.
The BasePlotter assumes a equidistantly sampled time series in which
case the x-values of the plot are set to [0, 1,..., seq_len].
"""
return np.array([range(data.shape[0]), data])
[docs] def create_plot(
self,
data: np.ndarray,
color: str = 'blue',
size: Tuple[int, int] = (256, 256),
y_lim: Optional[Tuple[float, float]] = None,
label: Optional[int] = None,
) -> Figure:
"""
Generate a plot from given data (default: 2D time series).
How the plot is created is subject to the given class implementation. The plot is returned as a mpl Figure.
Args:
data: A list of data points representing a 2D time series.
color: The color in which to draw the plot.
size: The size the plot image is resized to. Resizing is done via cubic interpolation.
y_lim: Limit the y-axis of the plot to the floating tuple.
label: Optional label which can be drawn into the plot.
Returns:
The generated images.
"""
fig, axs = matplotlib_prep(size, y_lim=y_lim)
self._plot(data=data, axes=axs, color=color, label=label)
return fig
[docs] def get_sampling_grid(
self,
sample_data: Union[Tensor, np.ndarray],
max_num_series: int = 16,
row_width: int = 4,
color: str = 'blue',
scale_per_batch: bool = False,
label: Optional[Union[Tensor, np.ndarray]] = None,
x_axis: bool = True,
y_axis: bool = True,
fig_size: Optional[Tuple[float, float]] = None,
) -> Figure:
"""
Get sampled time series data image grid.
Args:
sample_data: The data that shall be visualized with shape (amount_of_series, measurements, channels).
max_num_series: Maximum amount of samples visualized.
row_width: Row width.
color: Sets the colour of the plots.
scale_per_batch: Set the y-limit for each plot to min/max of the batch (per channel).
label: Optional labels parameter, which is written to the plots.
x_axis: Flag indicating whether the x-axis should be visible.
y_axis: Flag indicating whether the y-axis should be visible.
fig_size: Optional size for the figure in inches.
Returns:
Figure of the image grid.
"""
sample_data_ = to_numpy(sample_data)
if fig_size is None:
# Make the figure wider if more than the default 4 figures are in one row.
extended_row_width = row_width / 5 if row_width > 4 else 1.0
arbitrary_extension = max_num_series / (row_width * 5)
arbitrary_extension = 1.0 if arbitrary_extension < 1 else arbitrary_extension
# DIN A4 size in inches: (11.69, 8.27)
extended_height = 8.27 * arbitrary_extension
extended_width = 11.69 * arbitrary_extension * extended_row_width
fig_size = (extended_width, extended_height)
fig = plt.figure(figsize=fig_size)
num_data, seq_len, num_channels = sample_data_.shape
num_data = min(num_data, max_num_series)
# Univariate series
if num_channels == 1:
sample_data_.reshape((num_data, seq_len, num_channels))
num_of_rows = math.ceil(num_data / row_width)
outer_grid = GridSpec(num_of_rows, row_width, wspace=0.1, hspace=0.1)
y_lim = None
if scale_per_batch:
y_lim = (np.min(sample_data_), np.max(sample_data_))
for i in range(num_of_rows):
for j in range(row_width):
inner_grid = GridSpecFromSubplotSpec(
num_channels,
1,
subplot_spec=outer_grid[i * row_width + j],
wspace=0.3,
hspace=0.15,
)
for k in range(num_channels):
plot_number = i * row_width + j
if plot_number >= num_data:
break
ax = plt.Subplot(fig, inner_grid[k])
ax.get_xaxis().set_visible(x_axis)
ax.get_yaxis().set_visible(y_axis)
if y_lim is not None:
ax.set_ylim(y_lim[0], y_lim[1])
if k != num_channels - 1:
ax.get_xaxis().set_visible(False)
label_ = int(label[plot_number]) if label is not None else None
self._plot(
sample_data_[plot_number, :, k],
label=label_,
axes=ax,
color=color,
)
fig.add_subplot(ax)
outer_grid.tight_layout(fig)
return fig
def _plot(
self,
data: np.ndarray,
axes: Axes,
color: str,
label: Optional[int] = None,
) -> None:
"""
Plot data to given axes.
Args:
data: Data to plot.
axes: Axes to plot the data on.
color: Graph color.
label: Numeric label of the data class.
"""
transformed_data = self._get_data_axes(data)
x, y = transformed_data
axes.plot(x, y, color=color)
if label is not None:
axes.set_title('Class: {}'.format(label), loc='center')
[docs] def save_sampling_grid(
self,
sample_data: Union[Tensor, np.ndarray],
file_location: str,
color: str = 'blue',
max_num_series: int = 16,
scale_per_batch=False,
row_width: int = 4,
label: Optional[Union[Tensor, np.ndarray]] = None,
) -> None:
"""
Save sampled time series data to an image grid.
Args:
sample_data: The data that shall be visualized with shape (amount_of_series, 1, measurements).
color: Color of the plot.
file_location: Path to file on local system. Should be a UNIQUE identifier.
max_num_series: Maximum amount of samples visualized.
row_width: Width of a row.
scale_per_batch: Set the y-limit for each plot to min/max of the batch (per channel).
label: Optional labels parameter, which is written to the plots.
"""
image_grid = self.get_sampling_grid(
sample_data,
max_num_series,
row_width,
color=color,
scale_per_batch=scale_per_batch,
label=label,
)
image_grid.savefig(file_location)
[docs] @staticmethod
def create_histogram(
data: np.ndarray,
title: str,
x_label: str = '',
y_label: str = '',
bins: int = 50,
color: str = 'g',
) -> Figure:
"""Create a histogram of given data."""
fig = plt.figure()
plt.hist(data, bins, density=True, facecolor=color, alpha=0.75)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
plt.grid(True)
return fig
[docs] @staticmethod
def save_plot(
plot: Union[Figure, np.ndarray],
file_location: str,
) -> None:
"""Save a plot (encoded as np.ndarray) to file_location."""
if isinstance(plot, Figure):
plot.savefig(file_location)
plt.close(plot)
else:
plt.imsave(file_location, np.ascontiguousarray(plot))
[docs] @staticmethod
def create_error_plot( # pylint: disable=R0913
data_lined: Union[np.ndarray, Tensor],
data_dashed: Union[np.ndarray, Tensor],
heatmap_data: Optional[np.ndarray] = None,
x_axis: Optional[np.ndarray] = None,
data_range: Optional[Tuple[float, float]] = None,
color_lined: str = 'blue',
color_dashed: str = 'red',
color_map: str = 'plasma',
x_label: str = '',
y_label: str = '',
title: str = '',
) -> Figure:
"""
Create a plot with two graphs visualizing the difference of the two samples and a heatmap.
Args:
data_lined: The sample data that is depicted as a solid line.
data_dashed: The sample data that is depicted as a dashed line.
heatmap_data: Optional data for the drawing of the heatmap.
The default simply computes the absolute value of (data_dashed - data_lined)
x_axis: Sampling of the x-axis. Default assumes range(1, len_of_samples).
data_range: Range of the heatmap. Dynamic by default, requires (min,max) otherwise.
color_lined: Color of the lined plot.
color_dashed: Color of the dashed plot.
color_map: Color map for the heatmap.
x_label: Label of the x-axis.
y_label: Label of the y-axis.
title: Title for the plot.
Returns:
Heatmap Figure.
"""
data_lined = to_numpy(data_lined)
data_dashed = to_numpy(data_dashed)
if x_axis is None:
len_ = len(data_lined) if len(data_lined) > len(data_dashed) else len(data_dashed)
x_axis = np.arange(len_)
if heatmap_data is None:
heatmap_data = abs(data_dashed - data_lined)
_, axes = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [6, 1]})
axes[0].plot(x_axis, data_lined, color=color_lined)
axes[0].plot(x_axis, data_dashed, color=color_dashed, linestyle="--")
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
axes[1].get_yaxis().set_visible(False)
if data_range is not None:
norm = mplcolors.Normalize(vmin=data_range[0], vmax=data_range[1])
heatmap = axes[1].imshow([heatmap_data], cmap=color_map, aspect='auto', norm=norm)
else:
heatmap = axes[1].imshow([heatmap_data], cmap=color_map, aspect='auto')
plt.colorbar(heatmap, ax=axes)
return plt.gcf()
[docs]class FourierPlotter(BasePlotter):
"""Plotter implementation to plot data in the Fourier-domain as a frequency-histogram."""
@staticmethod
def _transform_to_complex(data: Union[Tensor, np.ndarray]) -> Tensor:
"""
Transform a tensor to a complex tensors.
Args:
data: Input data of shape (batch_size, seq_len, 2*num_channels)
Returns:
Tensor of shape (batch_size, seq_length, num_channels) with complex entries.
"""
data = to_torch(data)
batch_size, seq_length, twice_num_channels = data.shape
num_channels = twice_num_channels // 2
dtype = torch.cfloat if data.data == torch.float else torch.cdouble
result = torch.empty((num_channels, batch_size, seq_length), dtype=dtype)
for channel in range(num_channels):
real = data[:, :, 2 * channel]
imag = data[:, :, 2 * channel + 1]
result[channel] = torch.complex(real, imag)
return result.permute(1, 2, 0)
[docs] def get_sampling_grid(
self,
sample_data: Union[Tensor, np.ndarray],
max_num_series: int = 16,
row_width: int = 4,
color: str = 'blue',
scale_per_batch: bool = False,
label: Optional[Union[Tensor, np.ndarray]] = None,
x_axis: bool = True,
y_axis: bool = True,
fig_size: Optional[Tuple[float, float]] = None,
) -> Figure:
"""
Get sampled time series data image grid.
Args:
sample_data: The data that shall be visualized. The tensor is assumed to be of shape
(batch_size, seq_length, 2 * num_of_channels), where each two channels form the real and imaginary parts
of the Fourier coefficients.
max_num_series: Maximum amount of samples visualized.
row_width: Amount of columns if grid shall not be symmetric.
color: Sets the colour of the plots.
scale_per_batch: Set the y-limit for each plot to min/max of the batch (per channel).
label: Optional labels parameter, which is written to the plots.
x_axis: Flag indicating whether the x-axis should be visible.
y_axis: Flag indicating whether the y-axis should be visible.
fig_size: Optional size for the figure in inches.
Returns:
Image grid as Figure.
"""
# First restore the tensor to a complex number.
complex_tensor = self._transform_to_complex(sample_data)
# Now the sampling can continue as declared in the super class.
return super().get_sampling_grid(
complex_tensor,
max_num_series,
row_width,
color,
scale_per_batch,
label,
)
@staticmethod
def _transform_data(data: Union[Tensor, np.ndarray]) -> np.ndarray:
"""Transform a data tensor in Fourier domain to a histogram of form (freq, abs(data))."""
data = to_numpy(data)
seq_len = data.shape[0]
sample_rate = 2 * seq_len # Nyquist theorem
freqs: np.ndarray = fftpack.fftfreq(len(data)) * sample_rate
return freqs
[docs] def create_plot(
self,
data: np.ndarray,
color: str = 'blue',
size: Tuple[int, int] = (256, 256),
y_lim: Optional[Tuple[float, float]] = None,
label: Optional[int] = None,
) -> Figure:
"""Create a plot of the data using the settings of the plotting class."""
transformed_data = self._transform_data(data)
frequencies = transformed_data
abs_ = np.abs(data)
x = range(data.shape[0])
y = torch.fft.ifftn(torch.from_numpy(data), norm='ortho').numpy()
y = np.real(y)
bottom, top = plt.ylim()
fig, ax = matplotlib_prep(size, subplots=2, y_lim=y_lim)
ax[1].plot(x, y, color=color)
plt.ylim(bottom, top)
if label is not None:
ax[0].set_title('Class: {}'.format(label), loc='center')
ax[0].stem(frequencies, abs_, markerfmt=" ")
return fig
[docs]class ScatterPlotter(BasePlotter):
"""Plotter specialized on scatter plots for large datasets."""
[docs] @staticmethod
def truncate_colormap(
cmap: mplcolors.Colormap, min_val: float = 0.0, max_val: float = 1.0, n_cmap: int = 100
) -> mplcolors.LinearSegmentedColormap:
"""
Truncate a given matplotlib colormap to a specific range.
Args:
cmap: matplotlib colormap object.
min_val: Minimum bound of colormap.
max_val: Maximum bound of colormap.
n_cmap: Number of samples from original cmap.
Returns:
New colormap object
"""
new_cmap = mplcolors.LinearSegmentedColormap.from_list(
'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=min_val, b=max_val),
cmap(np.linspace(min_val, max_val, n_cmap)),
)
return new_cmap
[docs] @staticmethod
def plot_scatter(
data: np.ndarray,
target: Union[List[int], object],
fig_title: Optional[str],
classes: Optional[List[str]],
dpi: int = 300,
alpha: float = 1,
cmap: str = 'plasma',
) -> Figure:
"""
Visualizes the resulting low dimensional (2D or 3D) embedding.
Args:
data: Low dimensional embedding, either 2D or 3D plots, i.e. shape (n_samples, 2) or (n_samples, 3).
target: Target values used for visualization encoded as integers.
fig_title: Description of the model saved. Should include the name of the embedding.
classes: Class names according to the encoding in `target`.
dpi: DPI of the resulting figure.
alpha: Alpha blending value. Is between 0 (transparent) and 1 (opaque).
cmap: Colormap.
Returns:
Figure containing the scatter plot.
"""
assert (
data.shape[1] == 2 or data.shape[1] == 3
), 'Can only visualize 2D or 3D embeddings input shape was {0}.'.format(data.shape)
fig = plt.figure(figsize=(14, 10))
fig._set_dpi(dpi) # pylint: disable=W0212
if data.shape[1] == 2:
ax = fig.add_subplot(111)
scatterplot = ax.scatter(
data[:, 0],
data[:, 1],
alpha=alpha,
cmap=cmap,
c=target,
s=1,
)
plt.setp(ax, xticks=[], yticks=[])
else:
ax = fig.add_subplot(111, projection='3d')
scaling = 3 / np.log10(data.shape[0]) if data.shape[0] > 5000 else 30 / np.log10(data.shape[0])
scatterplot = ax.scatter(
data[:, 0],
data[:, 1],
data[:, 2],
s=scaling,
alpha=1.0,
cmap=cmap,
c=target,
)
plt.setp(ax, xticks=[], yticks=[], zticks=[])
if classes is not None:
cbar = fig.colorbar(scatterplot, boundaries=np.arange((len(classes) + 1)) - 0.5)
cbar.set_ticks(np.arange(len(classes)))
cbar.set_ticklabels(classes)
if fig_title is not None:
plt.title(fig_title, fontsize=18, y=1.03)
return fig
[docs] @staticmethod
def plot_interpolation_path(
data: np.ndarray,
labels: np.ndarray,
trace: np.ndarray,
classes: Optional[List[str]],
fig_size: Tuple[float, float] = (10, 7.5),
cmap: str = 'plasma',
cmap_range: Tuple[float, float] = (0.0, 1.0),
path_color: str = 'r',
scatter_alpha: float = 1.0,
) -> Figure:
"""Plot a trace between points in a scatter plot."""
fig, ax = plt.subplots(figsize=fig_size)
# Plot latent embedding
new_cmap = ScatterPlotter.truncate_colormap(plt.get_cmap(cmap), cmap_range[0], cmap_range[1])
scatterplot = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap=new_cmap, alpha=scatter_alpha, s=0.5)
# Create legend
if classes is not None:
cb = fig.colorbar(scatterplot, shrink=0.75, boundaries=np.arange((len(classes) + 1)) - 0.5)
cb.solids.set(alpha=1)
cb.set_ticks(np.arange(len(classes)))
cb.set_ticklabels(classes)
# Plot walk route
if trace is None:
raise RuntimeError("No trace supplied to `ScatterPlotter.plot_interpolation_path`.")
ax.plot(trace[:, 0], trace[:, 1], c=path_color)
ax.scatter(trace[0, 0], trace[0, 1], s=0.5, c=path_color, marker='X', alpha=1.0)
ax.arrow(
trace[-2, 0],
trace[-2, 1],
trace[-1, 0] - trace[-2, 0],
trace[-1, 1] - trace[-2, 1],
color=path_color,
head_width=0.1,
head_length=0.1,
)
# fig.tight_layout()
ax.axis('off')
return fig
[docs]class PlotterFactory:
"""Used to retrieve instance of desired plotter."""
[docs] @staticmethod
def choose_class(plotter_type: PlotterType):
"""Choose the correct class based on the provided plotter name."""
if plotter_type == PlotterType.FOURIER:
return FourierPlotter
if plotter_type == PlotterType.BASE:
return BasePlotter
if plotter_type == PlotterType.SCATTER:
return ScatterPlotter
raise AttributeError('Argument {0} is not set correctly.'.format(plotter_type))
def __call__(self, plotter_type: PlotterType, **kwargs) -> BasePlotter:
"""Create and return Plotter object."""
cls = PlotterFactory.choose_class(plotter_type)
base_plotter: BasePlotter = cls()
return base_plotter
[docs] @staticmethod
def from_config(train_cfg: TrainerConfig) -> BasePlotter:
"""Generate a plotter from a config dictionary."""
if train_cfg.transformation == Transformation.FOURIER:
plotter = PlotterFactory()(PlotterType.FOURIER)
else:
plotter = PlotterFactory()(PlotterType.BASE)
return plotter
[docs]def visualize_reconstruction(
series: torch.Tensor,
plotter: BasePlotter,
max_intermediate_samples: int = 10,
) -> Figure:
"""
Visualize a fixed amount of series for a variable amount of input series.
Total steps=max_intermediate_samples+original sample+final sample.
Args:
series: Tensor of series that shall be reconstructed.
plotter: The plotter to use for the visualization.
max_intermediate_samples: Maximum amount of interpolation steps.
Returns:
A mpl Figure containing a sequence of series.
"""
max_intermediate_samples = len(series) if len(series) // max_intermediate_samples == 0 else max_intermediate_samples
idx_list = list(range(0, len(series), len(series) // max_intermediate_samples))
idx_list.append(len(series) - 1)
return plotter.get_sampling_grid(series[idx_list])