Generative Base Modules
Abstract base generative module.
- class ecgan.modules.generative.base.generative_module.BaseGenerativeModule(cfg, seq_len, num_channels)[source]
Bases: ecgan.modules.base.BaseModule
Abstract base generative module containing several generative metrics.
- abstract get_sample(num_samples=None, data=None)[source]
Generate a sample.
The sample is based either on random noise (requires the number of samples) or on original data if a reconstruction-based GAN is chosen.
- Return type
Tuple[Tensor, Tensor]
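The contract of get_sample can be sketched with two toy subclasses: one that samples random noise and therefore needs num_samples, and one that reconstructs given data and therefore needs data. Everything here is hypothetical: plain nested lists stand in for tensors, the placeholder "networks" are trivial, and the returned tuple is assumed to be (latent codes, generated series), which the signature above does not state.

```python
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

# Hypothetical stand-in for a (seq_len, num_channels) tensor: one list of channel values per time step.
Series = List[List[float]]


class GenerativeSketch(ABC):
    """Hypothetical stand-in for BaseGenerativeModule; not the ecgan implementation."""

    def __init__(self, seq_len: int, num_channels: int, latent_dim: int = 4) -> None:
        self.seq_len = seq_len
        self.num_channels = num_channels
        self.latent_dim = latent_dim

    @abstractmethod
    def get_sample(
        self, num_samples: Optional[int] = None, data: Optional[List[Series]] = None
    ) -> Tuple[List[List[float]], List[Series]]:
        """Return an assumed (latent codes, generated series) pair."""


class NoiseGANSketch(GenerativeSketch):
    """Samples from random noise, so num_samples is required."""

    def get_sample(self, num_samples=None, data=None):
        if num_samples is None:
            raise ValueError("Noise-based sampling requires num_samples.")
        z = [[random.gauss(0.0, 1.0) for _ in range(self.latent_dim)] for _ in range(num_samples)]
        # Placeholder generator: repeat the mean of each latent vector over all steps and channels.
        x = [[[sum(v) / len(v)] * self.num_channels for _ in range(self.seq_len)] for v in z]
        return z, x


class ReconstructionGANSketch(GenerativeSketch):
    """Reconstructs given data, so data is required."""

    def get_sample(self, num_samples=None, data=None):
        if data is None:
            raise ValueError("Reconstruction-based sampling requires data.")
        # Placeholder encoder: one latent value per series (its mean); placeholder decoder: identity copy.
        z = [[sum(sum(step) for step in s) / (len(s) * len(s[0]))] for s in data]
        x = [[list(step) for step in s] for s in data]
        return z, x
```

Raising early on a missing argument keeps the two sampling modes explicit instead of silently falling back from one to the other.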
- get_tstr()[source]
Calculate TSTR (Train on Synthetic, Test on Real) values.
Requires a validation dataset from which the real data is drawn. Afterwards, the training data for TSTR is generated either by randomly sampling the latent space (e.g. for GAN-based models that use a random z vector) or by reconstructing samples from the validation data.
- Return type
Dict
- Returns
Dict containing TSTR statistics.
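TSTR trains a classifier only on generated data and reports how well it performs on real data. The sketch below uses a nearest-centroid classifier on flattened series purely for illustration; the classifier, the function names and the returned "accuracy" key are assumptions, not ecgan's implementation.

```python
from typing import Dict, List, Sequence

Vector = Sequence[float]  # a flattened series


def _fit_centroids(xs: List[Vector], ys: List[int]) -> Dict[int, List[float]]:
    """Mean feature vector per class label."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for x, y in zip(xs, ys):
        acc = sums.setdefault(y, [0.0] * len(x))
        for i, v in enumerate(x):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}


def _predict(centroids: Dict[int, List[float]], x: Vector) -> int:
    """Label of the closest centroid (squared Euclidean distance)."""
    return min(centroids, key=lambda y: sum((a - b) ** 2 for a, b in zip(x, centroids[y])))


def tstr(
    synthetic_x: List[Vector], synthetic_y: List[int], real_x: List[Vector], real_y: List[int]
) -> Dict[str, float]:
    """Train on Synthetic, Test on Real with a nearest-centroid classifier."""
    centroids = _fit_centroids(synthetic_x, synthetic_y)  # train on generated data only
    correct = sum(_predict(centroids, x) == y for x, y in zip(real_x, real_y))  # test on real data
    return {"accuracy": correct / len(real_y)}
```

A high TSTR score suggests that the generated data captures the class-relevant structure of the real data; a deliberately simple classifier keeps the focus on the train/test split rather than on model capacity.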
Base class for GAN modules.
- class ecgan.modules.generative.base.gan_module.BaseGANModule(cfg, seq_len, num_channels)[source]
Bases: ecgan.modules.generative.base.generative_module.BaseGenerativeModule
Base class from which all implemented GANs should inherit.
- property generator: torch.nn.modules.module.Module
Return the generator.
- Return type
Module
- property discriminator: torch.nn.modules.module.Module
Return the discriminator.
- Return type
Module
- property generator_sampler: ecgan.utils.sampler.GeneratorSampler
Return the sampler used to sample from the generator.
- Return type
GeneratorSampler
- property discriminator_sampler: ecgan.utils.sampler.DiscriminatorSampler
Return the sampler used to sample from the discriminator.
- Return type
DiscriminatorSampler
- property optim_gen: ecgan.utils.optimizer.BaseOptimizer
Return the optimizer for the generator.
- Return type
BaseOptimizer
- property optim_disc: ecgan.utils.optimizer.BaseOptimizer
Return the optimizer for the discriminator.
- Return type
BaseOptimizer
- property criterion_gen: ecgan.utils.losses.GANBaseLoss
Return the criterion for the generator.
- Return type
GANBaseLoss
- property criterion_disc: ecgan.utils.losses.GANBaseLoss
Return the criterion for the discriminator.
- Return type
GANBaseLoss
- training_step(batch)[source]
Declare what the model should do during a training step using a given batch.
- Parameters
batch (dict) -- The batch of real data and labels. Labels are always 0 if trained on normal data only.
- Return type
dict
- Returns
A dict containing the optimization metrics to be logged.
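A GAN training step commonly alternates one discriminator update with one generator update and returns scalar losses for logging. The skeleton below shows only that control flow and the returned dict; the batch keys "data" and "label", the update callables and the metric names are hypothetical, and ecgan may schedule its updates differently.

```python
from typing import Callable, Dict, List


def gan_training_step(
    batch: Dict[str, List],
    update_discriminator: Callable[[List], float],
    update_generator: Callable[[int], float],
) -> Dict[str, float]:
    """Hypothetical skeleton: one discriminator update, then one generator update."""
    real = batch["data"]  # assumed key for the real samples
    labels = batch["label"]  # assumed key; all 0 when training on normal data only
    if len(real) != len(labels):
        raise ValueError("Each sample needs exactly one label.")
    # The discriminator sees the real batch (and, inside the callable, freshly generated fakes).
    loss_disc = update_discriminator(real)
    # The generator is updated with as many generated samples as there are real ones.
    loss_gen = update_generator(len(real))
    # Only scalar metrics are returned so they can be logged directly.
    return {"loss_discriminator": loss_disc, "loss_generator": loss_gen}
```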
- validation_step(batch)[source]
Declare what the model should do during a validation step.
- Return type
dict
- load(model_reference, load_optim=False)[source]
Load a trained module from an existing model_reference.
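The load_optim flag separates restoring network weights from restoring optimizer state, which is only needed to resume training. The sketch below assumes model_reference has already been resolved to a checkpoint dict; the class, its keys and the dict-based state are hypothetical.

```python
from typing import Any, Dict, Optional


class LoadableGANSketch:
    """Hypothetical module whose components keep their state in plain dicts."""

    def __init__(self) -> None:
        self.model_state: Dict[str, Optional[Any]] = {"generator": None, "discriminator": None}
        self.optim_state: Dict[str, Optional[Any]] = {"optim_gen": None, "optim_disc": None}

    def load(self, checkpoint: Dict[str, Any], load_optim: bool = False) -> "LoadableGANSketch":
        # Network weights are always restored, e.g. for inference.
        for key in self.model_state:
            self.model_state[key] = checkpoint[key]
        # Optimizer state only matters when resuming training, so it is opt-in.
        if load_optim:
            for key in self.optim_state:
                self.optim_state[key] = checkpoint[key]
        return self
```

Defaulting load_optim to False matches the signature above: inference or evaluation does not need optimizer state.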
- property watch_list: List[torch.nn.modules.module.Module]
Return models that should be watched during training.
- Return type
List[Module]
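The read-only properties and watch_list above follow a common pattern: components are set once and exposed through properties without setters. A minimal sketch with placeholder objects instead of torch modules (all names hypothetical):

```python
from typing import List


class GANComponentsSketch:
    """Hypothetical container exposing its networks through read-only properties."""

    def __init__(self, generator: object, discriminator: object) -> None:
        self._generator = generator
        self._discriminator = discriminator

    @property
    def generator(self) -> object:
        """Return the generator."""
        return self._generator

    @property
    def discriminator(self) -> object:
        """Return the discriminator."""
        return self._discriminator

    @property
    def watch_list(self) -> List[object]:
        """Return the networks to watch during training."""
        return [self.generator, self.discriminator]
```

A list like this could, for example, be passed to W&B's wandb.watch, which accepts a single model or a list of torch.nn.Module objects; whether ecgan uses its watch list exactly this way is an assumption.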
Base class for encoder based GANs.
- class ecgan.modules.generative.base.encoder_gan_module.BaseEncoderGANModule(cfg, seq_len, num_channels)[source]
Bases: ecgan.modules.generative.base.gan_module.BaseGANModule
Base class for GANs with an autoencoder as generator.
- property discriminator_sampler: ecgan.utils.sampler.FeatureDiscriminatorSampler
Return the sampler used to sample from the discriminator.
- Return type
FeatureDiscriminatorSampler
- property generator_sampler: ecgan.utils.sampler.EncoderBasedGeneratorSampler
Return the sampler used to sample from the generator.
- Return type
EncoderBasedGeneratorSampler
- property watch_list: List[torch.nn.modules.module.Module]
Return models that should be watched during training.
- Return type
List[Module]
- training_step(batch)[source]
Declare what the model should do during a training step using a given batch.
- Parameters
batch (dict) -- The batch of real data.
- Return type
dict
- Returns
A dict containing the optimization metrics to be logged.
- validation_step(batch)[source]
Perform a validation step.
This method states the validation or inference process for one given batch.
- Parameters
batch (dict) -- Dictionary containing training tensors.
- Return type
dict
- Returns
Dictionary with metrics to log (e.g. loss).
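For a GAN with an autoencoder as generator, one natural validation metric is the reconstruction error on each batch. The sketch below computes a mean squared reconstruction error; the metric, the "data" key and the reconstruct callable are assumptions, not the metrics ecgan actually logs.

```python
from typing import Callable, Dict, List

Series = List[float]


def reconstruction_validation_step(
    batch: Dict[str, List[Series]], reconstruct: Callable[[Series], Series]
) -> Dict[str, float]:
    """Hypothetical validation step: mean squared reconstruction error over a batch."""
    data = batch["data"]  # assumed key for the input series
    per_sample = []
    for series in data:
        recon = reconstruct(series)  # stands in for encoder followed by decoder
        per_sample.append(sum((a - b) ** 2 for a, b in zip(series, recon)) / len(series))
    return {"val_reconstruction_mse": sum(per_sample) / len(per_sample)}
```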
- on_epoch_end(epoch, sample_interval, batch_size)[source]
Every sample_interval-th epoch, sample the reconstructions of the previously set fixed samples from the generator and walk through the latent space to check how the generated data changes along the walk.
Every tenth epoch and for the last 30 epochs, check metrics using the optimization procedure for the reconstruction and discrimination loss.
- Return type
List[Artifact]
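The end-of-epoch schedule for on_epoch_end can be expressed as a small helper, together with a latent space walk. The sketch assumes 0-indexed epochs, a hypothetical num_epochs parameter (the signature above does not include the total number of epochs) and linear interpolation between two latent vectors; ecgan may count epochs or interpolate differently.

```python
from typing import Dict, List


def epoch_end_tasks(epoch: int, sample_interval: int, num_epochs: int) -> Dict[str, bool]:
    """Decide which end-of-epoch tasks run (0-indexed epochs, hypothetical num_epochs)."""
    sample = epoch % sample_interval == 0
    # Metrics every tenth epoch and for each of the last 30 epochs.
    evaluate = epoch % 10 == 0 or epoch >= num_epochs - 30
    return {"sample_fixed": sample, "latent_walk": sample, "evaluate_metrics": evaluate}


def latent_walk(z_start: List[float], z_end: List[float], steps: int) -> List[List[float]]:
    """Linearly interpolate from z_start to z_end, including both endpoints."""
    if steps < 2:
        raise ValueError("A walk needs at least two points.")
    return [
        [a + (b - a) * t / (steps - 1) for a, b in zip(z_start, z_end)]
        for t in range(steps)
    ]
```

Each point of the walk would then be decoded by the generator, so that gradual changes in the latent code can be compared with the changes in the generated series.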
- load(model_reference, load_optim=False)[source]
Load a trained module from disk (file path) or from a W&B (wandb) reference.
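Loading from either a file path or a W&B reference requires deciding which kind of reference was given. A minimal sketch, assuming W&B references use the "entity/project/run_id" run-path format; the function name and its return values are hypothetical, and no download is performed.

```python
import os
from typing import Tuple


def resolve_model_reference(model_reference: str) -> Tuple[str, str]:
    """Hypothetical dispatcher: ('file', path) for existing files, otherwise ('wandb', reference)."""
    if os.path.isfile(model_reference):
        return "file", os.path.abspath(model_reference)
    # Assumed W&B run-path format "entity/project/run_id"; the run is not checked to exist.
    parts = model_reference.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"Not a file and not an entity/project/run_id reference: {model_reference!r}")
    return "wandb", model_reference
```

Checking the file system first means a local checkpoint always takes precedence over a W&B reference with the same string.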