Generative Base Modules

Abstract base generative module.

class ecgan.modules.generative.base.generative_module.BaseGenerativeModule(cfg, seq_len, num_channels)

Bases: ecgan.modules.base.BaseModule

Abstract base generative module containing several generative metrics.

abstract get_sample(num_samples=None, data=None)

Generate a sample.

The sample is generated either from random noise (requires the number of samples) or from original data if a reconstruction-based GAN is used.

Return type

Tuple[Tensor, Tensor]
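The two generation modes can be illustrated with toy subclasses of an abstract base. This is a minimal sketch using Python's `abc` module and NumPy; `ToyBase`, `NoiseModule`, `ReconstructionModule`, the latent size and the single-channel `(num_samples, seq_len)` shape are hypothetical. What the two returned tensors hold is not documented here, so the sketch assumes the pair is (model input, generated output).

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import numpy as np


class ToyBase(ABC):
    """Hypothetical stand-in for an abstract generative module."""

    @abstractmethod
    def get_sample(
        self, num_samples: Optional[int] = None, data: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return an (input, output) pair; this pairing is an assumption."""


class NoiseModule(ToyBase):
    """Noise-based generation: needs num_samples, ignores data."""

    def __init__(self, latent_dim: int = 8, seq_len: int = 16, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        # Fixed random linear map standing in for a trained generator.
        self.weights = self.rng.normal(size=(latent_dim, seq_len))

    def get_sample(self, num_samples=None, data=None):
        if num_samples is None:
            raise ValueError("noise-based sampling requires num_samples")
        z = self.rng.normal(size=(num_samples, self.weights.shape[0]))
        return z, np.tanh(z @ self.weights)


class ReconstructionModule(ToyBase):
    """Reconstruction-based generation: needs data, ignores num_samples."""

    def get_sample(self, num_samples=None, data=None):
        if data is None:
            raise ValueError("reconstruction-based sampling requires data")
        # Identity "autoencoder" as a placeholder for encode + decode.
        return data, data.copy()
```

Declaring `get_sample` abstract means a subclass that forgets to implement it cannot be instantiated.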

get_tstr()

Calculate TSTR (train on synthetic, test on real) values.

Requires a validation dataset from which the real data is drawn. Synthetic data is then generated either by randomly sampling the latent space (e.g. GAN-based models that use a random z vector) or by retrieving reconstructions of the validation samples, which are then used for training.

Return type

Dict

Returns

Dict containing TSTR statistics.
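TSTR fits a classifier only on generated data and scores it on real held-out data. A minimal sketch with NumPy, assuming labelled synthetic samples flattened to `(num_samples, features)` and a nearest-centroid classifier chosen purely for brevity; `tstr_score` and the `"tstr_accuracy"` key are hypothetical, and the classifier and statistics used by the module may differ.

```python
import numpy as np


def nearest_centroid_fit(x: np.ndarray, y: np.ndarray) -> dict:
    """Compute one mean vector per class label."""
    return {label: x[y == label].mean(axis=0) for label in np.unique(y)}


def nearest_centroid_predict(centroids: dict, x: np.ndarray) -> np.ndarray:
    """Assign each sample the label of its closest centroid."""
    labels = np.array(list(centroids.keys()))
    means = np.stack([centroids[label] for label in labels])  # (classes, features)
    # Squared Euclidean distance from every sample to every centroid.
    dists = ((x[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
    return labels[dists.argmin(axis=1)]


def tstr_score(x_syn, y_syn, x_real, y_real) -> dict:
    """Train on synthetic data, test on real data."""
    centroids = nearest_centroid_fit(x_syn, y_syn)
    pred = nearest_centroid_predict(centroids, x_real)
    return {"tstr_accuracy": float((pred == y_real).mean())}
```

A high score suggests the synthetic data preserves the class structure of the real data well enough to train on it.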

get_mmd(num_samples=512, sigma=5.0)

Calculate the maximum mean discrepancy (MMD).

Parameters
  • num_samples (int) -- Number of samples used for the MMD calculation.

  • sigma (float) -- Sigma (bandwidth) of the Gaussian kernel used in the MMD.

Return type

float

Returns

MMD score.
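The MMD compares two sample sets through average kernel similarities: within the real set, within the generated set, and across both. A minimal sketch of the biased squared-MMD estimator with NumPy, assuming flattened samples and the kernel exp(-||x - y||^2 / (2 sigma^2)); whether the module uses this exact parametrization, and whether it reports MMD or squared MMD, is not stated here.

```python
import numpy as np


def gaussian_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2)) for all pairs."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma**2))


def mmd_squared(x: np.ndarray, y: np.ndarray, sigma: float = 5.0) -> float:
    """Biased estimate of the squared MMD between samples x and y."""
    k_xx = gaussian_kernel(x, x, sigma).mean()
    k_yy = gaussian_kernel(y, y, sigma).mean()
    k_xy = gaussian_kernel(x, y, sigma).mean()
    return float(k_xx + k_yy - 2.0 * k_xy)
```

Identical sample sets give a score of zero, and the score grows as the two distributions drift apart; sigma controls which distance scale the kernel is sensitive to.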

Base class for GAN modules.

class ecgan.modules.generative.base.gan_module.BaseGANModule(cfg, seq_len, num_channels)

Bases: ecgan.modules.generative.base.generative_module.BaseGenerativeModule

Base class from which all implemented GANs should inherit.

property generator: torch.nn.modules.module.Module

Return the generator.

Return type

Module

property discriminator: torch.nn.modules.module.Module

Return the discriminator.

Return type

Module

property generator_sampler: ecgan.utils.sampler.GeneratorSampler

Return the sampler used to sample from the generator.

Return type

GeneratorSampler

property discriminator_sampler: ecgan.utils.sampler.DiscriminatorSampler

Return the sampler used to sample from the discriminator.

Return type

DiscriminatorSampler

property optim_gen: ecgan.utils.optimizer.BaseOptimizer

Return the optimizer for the generator.

Return type

BaseOptimizer

property optim_disc: ecgan.utils.optimizer.BaseOptimizer

Return the optimizer for the discriminator.

Return type

BaseOptimizer

property criterion_gen: ecgan.utils.losses.GANBaseLoss

Return the criterion for the generator.

Return type

GANBaseLoss

property criterion_disc: ecgan.utils.losses.GANBaseLoss

Return the criterion for the discriminator.

Return type

GANBaseLoss

static configure()

Return the default configuration of a standard GAN.

Return type

Dict

training_step(batch)

Declare what the model should do during a training step using a given batch.

Parameters

batch (dict) -- The batch of real data and labels. Labels are always 0 if trained on normal data only.

Return type

dict

Returns

A dict containing the optimization metrics which shall be logged.
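For a standard GAN, a training step typically updates the discriminator on real and generated data and then updates the generator. The sketch below only computes the two loss values from the discriminator's output probabilities, assuming the standard binary cross-entropy formulation with a non-saturating generator loss; `bce`, `gan_losses` and the metric keys are hypothetical, and the actual criteria are whatever `criterion_gen` and `criterion_disc` provide.

```python
import numpy as np


def bce(probs: np.ndarray, target: float, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy of probabilities against a constant target."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(-(target * np.log(probs) + (1.0 - target) * np.log(1.0 - probs)).mean())


def gan_losses(d_real: np.ndarray, d_fake: np.ndarray) -> dict:
    """Losses of one step, given D's probabilities for real and generated data."""
    # Discriminator: label real data 1 and generated data 0.
    loss_disc = bce(d_real, 1.0) + bce(d_fake, 0.0)
    # Generator (non-saturating): push D's output on generated data toward 1.
    loss_gen = bce(d_fake, 1.0)
    return {"loss_disc": loss_disc, "loss_gen": loss_gen}
```

Returning the losses as a dict mirrors the documented contract of handing back metrics to be logged.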

validation_step(batch)

Declare what the model should do during a validation step.

Return type

dict

save_checkpoint()

Return current model parameters.

Return type

dict

load(model_reference, load_optim=False)

Load a trained module from an existing model_reference.

property watch_list: List[torch.nn.modules.module.Module]

Return models that should be watched during training.

Return type

List[Module]

on_epoch_end(epoch, sample_interval, batch_size)

Set the actions to be executed after an epoch ends.

Declare what should be done upon finishing an epoch (e.g. save artifacts or evaluate metrics).

Return type

List[Artifact]

get_sample(num_samples=None, data=None)

Generate a sample.

The sample is generated either from random noise (requires the number of samples) or from original data if a reconstruction-based GAN is used.

Return type

Tuple[Tensor, Tensor]

Base class for encoder based GANs.

class ecgan.modules.generative.base.encoder_gan_module.BaseEncoderGANModule(cfg, seq_len, num_channels)

Bases: ecgan.modules.generative.base.gan_module.BaseGANModule

Base class for GANs with an autoencoder as generator.

property discriminator_sampler: ecgan.utils.sampler.FeatureDiscriminatorSampler

Return the sampler used to sample from the discriminator.

Return type

FeatureDiscriminatorSampler

property generator_sampler: ecgan.utils.sampler.EncoderBasedGeneratorSampler

Return the sampler used to sample from the generator.

Return type

EncoderBasedGeneratorSampler

property watch_list: List[torch.nn.modules.module.Module]

Return models that should be watched during training.

Return type

List[Module]

training_step(batch)

Declare what the model should do during a training step using a given batch.

Parameters

batch (dict) -- The batch of real data.

Return type

dict

Returns

A dict containing the optimization metrics which shall be logged.

validation_step(batch)

Perform a validation step.

This method defines the validation or inference process for a single batch.

Parameters

batch (dict) -- Dictionary containing training tensors.

Return type

dict

Returns

Dictionary with metrics to log (e.g. loss).

on_epoch_end(epoch, sample_interval, batch_size)

Every sample_interval-th epoch:

  1. Sample the reconstructions of the previously set fixed samples from the generator.

  2. Walk through the latent space to check how the generated data changes along the path.

Every tenth epoch and during the last 30 epochs: check metrics using the optimization procedure for the reconstruction and discrimination losses.

Return type

List[Artifact]
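The epoch schedule described above can be written as two predicates. A minimal sketch in plain Python; `num_epochs` (the total number of training epochs) is a hypothetical parameter that is not part of the documented signature, epochs are assumed to be counted from 1, and the exact boundary conditions in the module may differ.

```python
def should_sample(epoch: int, sample_interval: int) -> bool:
    """Reconstruct the fixed samples and interpolate every sample_interval-th epoch."""
    return epoch % sample_interval == 0


def should_check_metrics(epoch: int, num_epochs: int) -> bool:
    """Check metrics every tenth epoch and during the last 30 epochs."""
    return epoch % 10 == 0 or epoch > num_epochs - 30


num_epochs = 100  # hypothetical total number of epochs
metric_epochs = [e for e in range(1, num_epochs + 1) if should_check_metrics(e, num_epochs)]
print(len(metric_epochs))  # 37: epochs 10, 20, ..., 70 plus epochs 71 to 100
```

Checking metrics densely near the end of training targets the period in which the final model is most likely to be selected.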

save_checkpoint()

Return current model parameters.

Return type

dict

load(model_reference, load_optim=False)

Load a trained module from disk (file path) or from a Weights & Biases (wandb) reference.

set_fixed_samples()

Set the fixed samples of the module.

Used in validation_step or on_epoch_end to have comparable samples across epochs. The samples are chosen so that approximately the same number belong to class 0 and class 1.

Return type

None
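A class-balanced, reproducible selection like the one described can be sketched with NumPy: draw half of the fixed samples from class 0 and half from class 1 with a seeded generator, so that the same samples are reused in every epoch. `select_fixed_indices`, the seed and the binary integer labels are assumptions for illustration.

```python
import numpy as np


def select_fixed_indices(labels: np.ndarray, num_samples: int, seed: int = 0) -> np.ndarray:
    """Pick indices with roughly equal counts of label 0 and label 1."""
    rng = np.random.default_rng(seed)
    per_class = num_samples // 2
    chosen = []
    for label, count in ((0, per_class), (1, num_samples - per_class)):
        candidates = np.flatnonzero(labels == label)
        # Never request more samples than the class provides.
        count = min(count, candidates.size)
        chosen.append(rng.choice(candidates, size=count, replace=False))
    return np.concatenate(chosen)
```

Balancing matters when class 1 is rare: a purely random draw from an imbalanced dataset could contain no class-1 samples at all.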

get_interpolated_samples(sample)

Interpolate through the latent space based on the fixed samples.
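Latent-space interpolation usually encodes two samples, blends their latent codes step by step, and decodes each blend. A minimal sketch with NumPy, assuming linear interpolation between exactly two latent vectors; `interpolate_latents` and the number of steps are hypothetical, and the method's actual scheme (for example spherical interpolation or more anchor points) is not specified here.

```python
import numpy as np


def interpolate_latents(z_start: np.ndarray, z_end: np.ndarray, num_steps: int = 5) -> np.ndarray:
    """Return num_steps latent vectors evenly spaced from z_start to z_end."""
    alphas = np.linspace(0.0, 1.0, num_steps)[:, None]  # shape (steps, 1)
    return (1.0 - alphas) * z_start[None, :] + alphas * z_end[None, :]


path = interpolate_latents(np.zeros(4), np.ones(4), num_steps=5)
# Each row would then be passed through the decoder to produce one sample.
```

Smooth changes in the decoded samples along the path indicate a well-structured latent space.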