"""Base class for encoder based GANs."""
from abc import abstractmethod
from logging import getLogger
from typing import Dict, List, Optional, cast
import torch
from numpy import argmax, histogram
from sklearn.svm import SVC
from torch import nn
from ecgan.config import EncoderGANConfig
from ecgan.evaluation.optimization import (
optimize_metric,
optimize_svm,
optimize_tau_single_error,
retrieve_labels_from_weights,
)
from ecgan.modules.generative.base.gan_module import BaseGANModule
from ecgan.utils.artifacts import Artifact, FileArtifact, ImageArtifact, ValueArtifact
from ecgan.utils.custom_types import LossMetricType, MetricType, SklearnSVMKernels
from ecgan.utils.distances import L2Distance
from ecgan.utils.interpolation import latent_walk
from ecgan.utils.miscellaneous import load_model
from ecgan.utils.sampler import EncoderBasedGeneratorSampler, FeatureDiscriminatorSampler
from ecgan.utils.transformation import MinMaxTransformation
logger = getLogger(__name__)
class BaseEncoderGANModule(BaseGANModule):
"""Base class for GANs with an autoencoder as generator."""
def __init__(
self,
cfg: EncoderGANConfig,
seq_len: int,
num_channels: int,
):
self.cfg = cast(EncoderGANConfig, cfg)
self._encoder = self._init_inverse_mapping()
super().__init__(
cfg=cfg,
seq_len=seq_len,
num_channels=num_channels,
)
self._encoder = nn.DataParallel(self.encoder)
self._encoder.to(self.device)
        # These have to be set after the data sampler has been added; not possible upon creation.
self.fixed_samples: Optional[torch.Tensor] = None
self.fixed_samples_labels: Optional[torch.Tensor] = None
self.svm_mu: SVC = SVC()
self.z_mu: float = 0.0
self.z_mode: float = 0.0
        self.gamma: float = 0.0  # Not currently supported via the saved grid search; use svm_mu for improved results.
self.normalization_params: Dict = {'reconstruction_error': {}, 'discrimination_error': {}, 'latent_error': {}}
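        # Min-max scaling parameters per error type, persisted via save_checkpoint.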
@property
def encoder(self):
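        """Return the encoder (inverse mapping) module."""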
return self._encoder
@property
def discriminator_sampler(self) -> FeatureDiscriminatorSampler:
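        """Return the discriminator sampler, cast to a FeatureDiscriminatorSampler."""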
return cast(FeatureDiscriminatorSampler, self._discriminator_sampler)
@property
def generator_sampler(self) -> EncoderBasedGeneratorSampler:
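        """Return the generator sampler, cast to an EncoderBasedGeneratorSampler."""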
return cast(EncoderBasedGeneratorSampler, self._generator_sampler)
@abstractmethod
def _init_inverse_mapping(self) -> nn.Module:
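        """Initialize the module used as the inverse mapping from data space to latent space."""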
raise NotImplementedError("EncoderGANModule needs to implement the `_init_inverse_mapping` method.")
@property
def watch_list(self) -> List[nn.Module]:
"""Return models that should be watched during training."""
return [self.generator, self.discriminator, self.encoder]
    def training_step(
self,
batch: dict,
) -> dict:
"""
Declare what the model should do during a training step using a given batch.
Args:
batch: The batch of real data.
        Returns:
A dict containing the optimization metrics which shall be logged.
"""
real_data = batch['data'].to(self.device)
self.generator.train()
self.discriminator.train()
self.encoder.train()
self._prepare_train_step()
try:
disc_metric_collection = []
gen_metric_collection = []
#########################################
# Update discriminator
#########################################
for _ in range(self.cfg.DISCRIMINATOR_ROUNDS):
# Retrieve losses and update gradients
disc_losses, disc_metrics = self.criterion_disc(real_data)
self.optim_disc.optimize(disc_losses)
disc_metric_collection.extend(disc_metrics)
#########################################
# Update generator
#########################################
for _ in range(self.cfg.GENERATOR_ROUNDS):
# Retrieve losses and update gradients internally
gen_losses, gen_metrics = self.criterion_gen(real_data)
self.optim_gen.optimize(gen_losses)
gen_metric_collection.extend(gen_metrics)
return self._evaluate_train_step(disc_metrics=disc_metric_collection, gen_metrics=gen_metric_collection)
except TypeError as err:
raise TypeError('Error during training: Config parameter was not correctly set: {0}.'.format(err)) from err
def _prepare_train_step(self):
"""Can be used to set dynamic variables."""
pass
def _evaluate_train_step( # pylint: disable=R0201
self, disc_metrics: LossMetricType, gen_metrics: LossMetricType
) -> Dict:
"""
Can be used to evaluate data without impacting the training.
        If multiple generator/discriminator rounds are used, the current metric logging
        and this method do not aggregate the collected metrics sufficiently.
"""
return {key: float(value) for (key, value) in disc_metrics + gen_metrics}
    def validation_step(self, batch: dict) -> dict:
"""
Perform a validation step.
        This method defines the validation or inference process for a single batch.
Args:
batch: Dictionary containing training tensors.
Returns:
Dictionary with metrics to log (e.g. loss).
"""
if not isinstance(self.discriminator_sampler, FeatureDiscriminatorSampler):
raise AttributeError("Encoder based GANs currently require a feature discrimination implementation.")
data = batch['data'].to(self.device)
label = batch['label'].to(self.device)
self.discriminator.eval()
self.generator.eval()
self.encoder.eval()
l2_distance = L2Distance()
if self.fixed_samples is None:
self.set_fixed_samples()
with torch.no_grad():
x_hat, latent_vector = self.generator_sampler.sample_generator_encoder(data=data)
features_fake = self.discriminator_sampler.sample_features(x_hat)
features_real = self.discriminator_sampler.sample_features(data)
rec_error = l2_distance(data, x_hat)
disc_error = l2_distance(features_real, features_fake)
            # Concatenate the per-batch results to accumulate tensors across all batches of one epoch.
self.reconstruction_error = torch.cat((self.reconstruction_error, rec_error), dim=0)
self.latent_vectors_vali = torch.cat((self.latent_vectors_vali, latent_vector), dim=0)
self.discrimination_error = torch.cat((self.discrimination_error, disc_error), dim=0)
self.label = torch.cat((self.label, label), dim=0)
return self._get_validation_results(data)
@abstractmethod
def _get_validation_results(self, data: torch.Tensor) -> Dict:
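        """Return the validation metrics of the current batch which shall be logged."""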
raise NotImplementedError("GANEncoder models need to implement `_get_validation_results` which will be logged.")
    def on_epoch_end(self, epoch: int, sample_interval: int, batch_size: int) -> List[Artifact]:
"""
        Every `sample_interval`-th epoch:
        1. Sample the reconstruction of the previously set fixed samples from the generator.
        2. Walk through the latent space to check how the generated data changes along the walk.
Every tenth epoch and for the last 30 epochs:
Check metrics using the optimization procedure for the reconstruction and discrimination loss.
"""
result: List[Artifact] = []
# Min-max normalize error:
scaler = MinMaxTransformation()
self.reconstruction_error = scaler.fit_transform(self.reconstruction_error.unsqueeze(1)).squeeze()
scaling_params = {key: value[0] for key, value in scaler.get_params().items()}
self.normalization_params['reconstruction_error'] = scaling_params
self.discrimination_error = scaler.fit_transform(self.discrimination_error.unsqueeze(1)).squeeze()
scaling_params = {key: value[0] for key, value in scaler.get_params().items()}
self.normalization_params['discrimination_error'] = scaling_params
if epoch % sample_interval == 0:
result.append(self._reconstruct_fixed_samples())
if self.fixed_samples is None or self.fixed_samples_labels is None:
raise RuntimeError("Fixed samples not set correctly.")
# Interpolate through latent space for normal and abnormal samples
result.append(self._get_interpolation_grid(self.fixed_samples[1], 'Normal Class'))
result.append(self._get_interpolation_grid(self.fixed_samples[-1], 'Abnormal Class'))
# Save the latent norms of the data
result.append(
FileArtifact(
'Latent vector distribution',
{
'latent_train': self.latent_vectors_train.cpu(),
'latent_vali': self.latent_vectors_vali.cpu(),
'labels': self.label.cpu(),
},
'latent_data_{}.pkl'.format(epoch),
)
)
# Get distribution of latent norm
latent_norm_train = torch.norm(self.latent_vectors_train.squeeze(), dim=1)
logger.info(
"latent norm train. Shape {}, mean {}, median {}, std {}".format(
latent_norm_train.shape,
torch.mean(latent_norm_train),
torch.median(latent_norm_train),
torch.std(latent_norm_train),
)
)
latent_norm_vali_abnormal = torch.norm(self.latent_vectors_vali[self.label != 0].squeeze(), dim=1)
logger.info(
"latent norm vali.Shape {}, mean {}, median {}, std {}".format(
latent_norm_vali_abnormal.shape,
torch.mean(latent_norm_vali_abnormal),
torch.median(latent_norm_vali_abnormal),
torch.std(latent_norm_vali_abnormal),
)
)
result.append(
ImageArtifact(
'Norm of latent vectors (normal train)',
self.plotter.create_histogram(
latent_norm_train.cpu().numpy(), 'Norm of latent vectors (normal train)'
),
)
)
result.append(
ImageArtifact(
'Norm of latent vectors (abnormal validation)',
self.plotter.create_histogram(
latent_norm_vali_abnormal.cpu().numpy(), 'Norm of latent vectors (abnormal vali)'
),
)
)
# Get train statistics
result.append(ValueArtifact('latent/z_mean_min', torch.min(self.latent_vectors_train).item()))
result.append(ValueArtifact('latent/z_mean_max', torch.max(self.latent_vectors_train).item()))
self.z_mu = torch.mean(latent_norm_train).item()
        # Approximate the mode (the median can be used as a simple alternative since the distribution scatters a lot).
mode_count, mode_val = histogram(latent_norm_train.cpu().numpy(), bins=50)
self.z_mode = mode_val[argmax(mode_count)]
# Get euclidean distance from origin
latent_norm_vali = torch.norm(self.latent_vectors_vali.squeeze(), dim=1)
        # Subtract the mode from each norm to approximately shift to the center of the chi distribution before min-max scaling.
scaled_latent_norm = scaler.fit_transform((latent_norm_vali.unsqueeze(1)) - self.z_mode).squeeze()
scaling_params = {key: value[0] for key, value in scaler.get_params().items()}
self.normalization_params['latent_error'] = scaling_params
# Get anomaly scores:
# 1. SVM on reconstruction error, discrimination error and latent norm
pred_latent_scaled, self.svm_mu = optimize_svm(
MetricType.FSCORE,
[
self.reconstruction_error.cpu().detach(),
self.discrimination_error.cpu().detach(),
scaled_latent_norm.cpu(),
],
self.label.cpu(),
)
result.extend(
self._get_metrics(
self.label, pred_latent_scaled, 'svm/scaled_latent', log_fscore=True, log_auroc=True, log_mcc=True
)
)
        # 2. SVM on reconstruction error and discrimination error
pred_minmax, self.svm = optimize_svm(
MetricType.FSCORE,
[
self.reconstruction_error.cpu().detach(),
self.discrimination_error.cpu().detach(),
],
self.label.cpu(),
)
result.extend(
self._get_metrics(self.label, pred_minmax, 'svm/minmax', log_fscore=True, log_auroc=True, log_mcc=True)
)
pred_linear_svm, _ = optimize_svm(
MetricType.FSCORE,
[
self.reconstruction_error.cpu().detach(),
self.discrimination_error.cpu().detach(),
],
self.label.cpu(),
kernel=SklearnSVMKernels.LINEAR,
)
result.extend(
self._get_metrics(
self.label, pred_linear_svm, 'svm/linear', log_fscore=True, log_auroc=False, log_mcc=False
)
)
mmd = self.get_mmd()
result.append(ValueArtifact('generative_metric/mmd', mmd))
tstr_dict = self.get_tstr()
result.append(ValueArtifact('generative_metric/tstr', tstr_dict))
        # Evaluate lambda=0 and lambda=1 for AnoGAN as well as gamma=1 for VAEGAN.
tau_range = torch.linspace(0, 2, 100).cpu().tolist()
result.append(
ValueArtifact(
'only_reconstruction_error',
optimize_tau_single_error(self.label.cpu(), self.reconstruction_error.cpu(), tau_range),
)
)
result.append(
ValueArtifact(
'only_disc_error',
optimize_tau_single_error(self.label.cpu(), self.discrimination_error.cpu(), tau_range),
)
)
result.append(
ValueArtifact(
'only_latent_error',
optimize_tau_single_error(self.label.cpu(), scaled_latent_norm.cpu(), tau_range),
)
)
# Optimize F-score every 10 epochs
if epoch % 10 == 0:
lambda_search_range = [torch.linspace(0, 1, 50).numpy().tolist()]
best_params = optimize_metric(
MetricType.FSCORE,
errors=[self.reconstruction_error.cpu(), self.discrimination_error.cpu()],
taus=torch.linspace(0, 2, 100).numpy().tolist(),
params=lambda_search_range,
ground_truth_labels=self.label.cpu(),
)
logger.info(
"Best params: {} for data {}.".format(best_params, torch.unique(self.label, return_counts=True))
)
self.tau = best_params[0][1]
self.lambda_ = best_params[0][2]
result.append(
ValueArtifact(
'grid/lambda_tau',
float(self.tau),
)
)
result.append(
ValueArtifact(
'grid/lambda',
float(self.lambda_),
)
)
predictions = retrieve_labels_from_weights(
errors=[self.reconstruction_error.cpu(), self.discrimination_error.cpu()],
tau=self.tau,
weighting_params=[self.lambda_],
)
result.extend(
self._get_metrics(
self.label,
predictions,
'grid/lambda',
log_fscore=True,
log_auroc=True,
log_mcc=True,
)
)
result.extend(self._on_epoch_end_addition(epoch, sample_interval))
self._reset_internal_tensors()
return result
def _on_epoch_end_addition(self, epoch: int, sample_interval: int) -> List[Artifact]: # pylint: disable=R0201,W0613
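        """Hook allowing subclasses to return additional artifacts at the end of an epoch."""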
return []
def _reconstruct_fixed_samples(
self,
) -> ImageArtifact:
"""
        Visualize the real (fixed) samples and the fake samples which aim to reconstruct them.
Returns:
Artifact containing a comparison of the real and reconstructed fixed samples.
"""
if self.fixed_samples is None or self.fixed_samples_labels is None:
raise RuntimeError("Fixed samples not set correctly.")
with torch.no_grad():
faked_samples, _ = self.generator_sampler.sample_generator_encoder(data=self.fixed_samples)
samples = torch.empty(
(
2 * self.num_fixed_samples,
faked_samples.shape[1],
faked_samples.shape[2],
)
)
labels = torch.empty((2 * self.num_fixed_samples, 1))
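        # Interleave real and reconstructed samples (and duplicate the labels) so each pair appears next to each other in the grid.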
for i in range(self.num_fixed_samples):
samples[2 * i] = self.fixed_samples[i]
samples[2 * i + 1] = faked_samples[i]
labels[2 * i] = self.fixed_samples_labels[i]
labels[2 * i + 1] = self.fixed_samples_labels[i]
return ImageArtifact(
'Fixed Generator Samples',
self.plotter.get_sampling_grid(
samples,
label=labels,
),
)
    def save_checkpoint(self) -> dict:
"""Return current model parameters."""
return {
'GEN': self.generator.state_dict(),
'DIS': self.discriminator.state_dict(),
'ENC': self.encoder.state_dict(),
'GEN_OPT': self.optim_gen.state_dict(),
'DIS_OPT': self.optim_disc.state_dict(),
'ANOMALY_DETECTION': {
'SVM': self.svm,
'LAMBDA': self.lambda_,
'GAMMA': self.gamma,
'TAU': self.tau,
'Z_MU': self.z_mu,
'Z_MODE': self.z_mode,
'SVM_MU': self.svm_mu,
'NORM_PARAMS': self.normalization_params,
},
'FIXED_SAMPLES': self.fixed_samples.detach().cpu().tolist()
if isinstance(self.fixed_samples, torch.Tensor)
else None, # temporary for tracking and graphing
}
    def load(self, model_reference: str, load_optim: bool = False):
"""Load a trained module from disk (file path) or wand reference."""
model = load_model(model_reference, self.device)
self.generator.load_state_dict(model['GEN'], strict=False)
self.discriminator.load_state_dict(model['DIS'], strict=False)
self.encoder.load_state_dict(model['ENC'], strict=False)
if load_optim:
self.optim_gen.load_existing_optim(model['GEN_OPT'])
self.optim_disc.load_existing_optim(model['DIS_OPT'])
logger.info('Loading existing {0} model completed.'.format(self.__class__.__name__))
self.svm = model['ANOMALY_DETECTION']['SVM']
self.tau = model['ANOMALY_DETECTION']['TAU']
self.lambda_ = model['ANOMALY_DETECTION']['LAMBDA']
self.gamma = model['ANOMALY_DETECTION']['GAMMA']
self.z_mu = model['ANOMALY_DETECTION']['Z_MU']
self.z_mode = model['ANOMALY_DETECTION']['Z_MODE']
self.svm_mu = model['ANOMALY_DETECTION']['SVM_MU']
self.normalization_params = model['ANOMALY_DETECTION']['NORM_PARAMS']
return self
    def set_fixed_samples(self) -> None:
"""
Set the fixed samples of the module.
Utilized in validation_step or on_epoch_end to have comparable samples across epochs.
It is made sure that approximately the same amount of samples belong to
class 0 and 1.
"""
fixed_samples_normal = self.vali_dataset_sampler.sample_class(self.num_fixed_samples // 2, 0)
        fixed_samples_abnormal = self.vali_dataset_sampler.sample_class(
            self.num_fixed_samples - self.num_fixed_samples // 2, 1
        )
        self.fixed_samples = torch.cat((fixed_samples_normal['data'], fixed_samples_abnormal['data'])).to(self.device)
        self.fixed_samples_labels = torch.cat((fixed_samples_normal['label'], fixed_samples_abnormal['label'])).to(
            self.device
        )
def _get_interpolation_grid(self, base_sample: torch.Tensor, class_name: str) -> ImageArtifact:
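        """Create an image artifact visualizing a latent walk around the given base sample."""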
interpolated_samples = self.get_interpolated_samples(base_sample.unsqueeze(0))
return ImageArtifact(
'{0} Interpolation Grid'.format(class_name),
self.plotter.get_sampling_grid(
interpolated_samples,
row_width=11,
scale_per_batch=True,
max_num_series=interpolated_samples.shape[0],
),
)
    def get_interpolated_samples(self, sample: torch.Tensor):
"""Interpolate through latent space based on fixed samples."""
# Investigate the latent space using one of the fixed samples and performing a latent walk.
with torch.no_grad():
_x_hat, inverse_mapping = self.get_sample(data=sample)
std = torch.std(inverse_mapping).item()
logger.info("Standard error of latent space is {}.".format(std))
walk_space = torch.linspace(-std, std, 11)
interpolated_samples = latent_walk(
inverse_mapping,
self.generator,
walk_range=walk_space,
device=self.device,
latent_dims=inverse_mapping.shape[2],
)
return interpolated_samples
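# A minimal sketch (not part of the library) of how a concrete subclass could implement the
# abstract methods; `MyEncoder` is a hypothetical nn.Module and the exact config attributes
# depend on the concrete EncoderGANConfig:
#
# class MyEncoderGAN(BaseEncoderGANModule):
#     def _init_inverse_mapping(self) -> nn.Module:
#         # Return the encoder network mapping input series into the latent space.
#         return MyEncoder(self.cfg)
#
#     def _get_validation_results(self, data: torch.Tensor) -> Dict:
#         # Log the mean errors accumulated during validation_step.
#         return {
#             'reconstruction_error': torch.mean(self.reconstruction_error).item(),
#             'discrimination_error': torch.mean(self.discrimination_error).item(),
#         }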