Loss functions
Implementation of various loss functions for PyTorch models.
- class ecgan.utils.losses.SupervisedLoss[source]
Bases:
ecgan.utils.configurable.Configurable
Base class for supervised loss functions.
- class ecgan.utils.losses.L2Loss(reduction='mean')[source]
Bases:
ecgan.utils.losses.SupervisedLoss
Wrapper around PyTorch's mean squared error loss (torch.nn.MSELoss).
- class ecgan.utils.losses.BCELoss(reduction='mean')[source]
Bases:
ecgan.utils.losses.SupervisedLoss
Wrapper around PyTorch's binary cross entropy loss (torch.nn.BCELoss).
- class ecgan.utils.losses.CrossEntropyLoss(reduction='mean')[source]
Bases:
ecgan.utils.losses.SupervisedLoss
Wrapper around PyTorch's cross entropy loss (torch.nn.CrossEntropyLoss).
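These supervised losses are thin shims over the corresponding torch.nn modules. A minimal sketch of what such a wrapper presumably looks like (the class body below is an assumption, not taken from the source):

```python
import torch
import torch.nn as nn


class L2Loss:
    """Hypothetical sketch: thin wrapper around torch.nn.MSELoss."""

    def __init__(self, reduction: str = 'mean'):
        self.loss = nn.MSELoss(reduction=reduction)

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.loss(prediction, target)


pred = torch.tensor([0.5, 1.0])
target = torch.tensor([1.0, 1.0])
print(L2Loss().forward(pred, target).item())  # mean of squared errors: 0.125
```

The `reduction` argument is forwarded unchanged, so `'mean'`, `'sum'`, and `'none'` behave exactly as in the underlying torch.nn loss.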
- class ecgan.utils.losses.KLLoss[source]
Bases:
object
Kullback-Leibler divergence loss for use in variational auto-encoders.
- static forward(mean_value, log_var)[source]
Calculate Kullback-Leibler divergence for standard Gaussian distribution.
Calculate the KL divergence for a given expected value and log variance. The input tensors are expected to be of shape (N x DIM), where N is the number of samples and DIM is the dimension of the multivariate Gaussian. The result is the average KL divergence between the batch of distributions and a unit Gaussian.
- Return type
Tensor
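The closed-form KL divergence between a diagonal Gaussian N(mu, sigma^2) and a unit Gaussian, summed over the dimensions and averaged over the batch, can be sketched as follows (the function name is assumed for illustration):

```python
import torch


def kl_standard_gaussian(mean_value: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, sigma^2) || N(0, I)), summed over DIM, averaged over N."""
    kl_per_sample = -0.5 * torch.sum(
        1 + log_var - mean_value.pow(2) - log_var.exp(), dim=1
    )
    return kl_per_sample.mean()


# A batch that already matches the unit Gaussian (mu = 0, log sigma^2 = 0)
# has zero divergence.
print(kl_standard_gaussian(torch.zeros(4, 8), torch.zeros(4, 8)).item())  # 0.0
```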
- class ecgan.utils.losses.GANBaseLoss(discriminator_sampler, generator_sampler)[source]
Bases:
ecgan.utils.configurable.Configurable
Base loss class for custom GAN losses.
- class ecgan.utils.losses.BceGeneratorLoss(discriminator_sampler, generator_sampler, reduction)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
BCE Loss using the PyTorch implementation.
- class ecgan.utils.losses.BceDiscriminatorLoss(discriminator_sampler, generator_sampler, reduction)[source]
Bases:
ecgan.utils.losses.BceGeneratorLoss
Two component BCE Loss using the PyTorch implementation.
The class assumes that corresponding `BaseSampler`s for each component are implemented. The fake data is sampled by the `sample` method of the provided generator. The BCE loss is commonly used when optimizing the discriminator of a vanilla GAN.
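The two-component objective can be sketched as follows (helper names assumed): the discriminator is trained to assign 1 to real and 0 to generated samples, while the generator loss pushes the discriminator's outputs on fakes toward 1.

```python
import torch
import torch.nn.functional as F


def bce_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Real samples are labeled 1, generated (fake) samples 0.
    real_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
    fake_loss = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
    return real_loss + fake_loss


def bce_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # The generator is rewarded when the discriminator outputs 1 on fakes.
    return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))


d_fake = torch.tensor([0.5])
print(bce_generator_loss(d_fake).item())  # -log(0.5), about 0.6931
```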
- class ecgan.utils.losses.WassersteinDiscriminatorLoss(discriminator_sampler, generator_sampler, gradient_penalty_weight=None, clipping_bound=None)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
Wasserstein loss for the discriminator.
- static apply_gp(input_tensor, target_tensor)[source]
The gradient penalty (GP) is applied outside the forward call during optimization.
- Return type
Tensor
- forward(training_data)[source]
Calculate the Wasserstein distance and minimize it using a given optimizer.
- Return type
Tuple[Union[Tensor, List[Tuple[str, Tensor]]], List[Tuple[str, Any]]]
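The core Wasserstein objective can be sketched as follows (helper names assumed): the critic maximizes D(real) - D(fake), so the minimized loss is its negation, while weight clipping or a gradient penalty enforces the Lipschitz constraint.

```python
import torch


def wasserstein_critic_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Minimizing this maximizes the estimated Wasserstein distance D(real) - D(fake).
    return d_fake.mean() - d_real.mean()


def wasserstein_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # The generator tries to increase the critic's score on generated samples.
    return -d_fake.mean()


d_real = torch.tensor([1.0, 2.0])
d_fake = torch.tensor([0.0, 1.0])
print(wasserstein_critic_loss(d_real, d_fake).item())  # 0.5 - 1.5 = -1.0
```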
- get_gradient_penalty(real_data, generated_data)[source]
Based on https://github.com/EmilienDupont/wgan-gp/blob/master/training.py.
- Return type
Tensor
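Following the referenced wgan-gp implementation, the penalty interpolates between real and generated samples and penalizes critic gradients whose norm deviates from 1. A sketch under assumed signatures (the real method takes the critic from the discriminator sampler):

```python
import torch


def gradient_penalty(critic, real_data, generated_data, gp_weight=10.0):
    batch_size = real_data.size(0)
    # Random interpolation point per sample between real and generated data.
    alpha = torch.rand(batch_size, 1).expand_as(real_data)
    interpolated = (alpha * real_data + (1 - alpha) * generated_data).requires_grad_(True)
    critic_out = critic(interpolated)
    gradients = torch.autograd.grad(
        outputs=critic_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
    )[0].view(batch_size, -1)
    # Penalize deviation of the gradient norm from 1 (Lipschitz constraint).
    return gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


# A critic that sums its inputs has constant gradient norm sqrt(D).
critic = lambda x: x.sum(dim=1, keepdim=True)
gp = gradient_penalty(critic, torch.randn(8, 4), torch.randn(8, 4))
print(gp.item())  # 10 * (sqrt(4) - 1)^2 = 10.0
```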
- class ecgan.utils.losses.WassersteinGeneratorLoss(discriminator_sampler, generator_sampler)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
Wasserstein loss for the generator.
- class ecgan.utils.losses.AEGANGeneratorLoss(discriminator_sampler, generator_sampler)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
Loss function for the auto-encoder based GANs.
- class ecgan.utils.losses.AEGANDiscriminatorLoss(discriminator_sampler, generator_sampler)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
Discriminator loss for an AEGAN module.
- class ecgan.utils.losses.VAEGANGeneratorLoss(discriminator_sampler, generator_sampler, reconstruction_loss, distribution, kl_beta, device)[source]
Bases:
ecgan.utils.losses.GANBaseLoss
Generator loss for the VAEGAN module.
- class ecgan.utils.losses.SupervisedLossFactory[source]
Bases:
object
Factory for creating the correct supervised loss function.
- class ecgan.utils.losses.GANLossFactory[source]
Bases:
object
Factory for creating the correct GAN loss function.
- class ecgan.utils.losses.AutoEncoderLoss(autoencoder_sampler, use_mse)[source]
Bases:
ecgan.utils.configurable.Configurable
Loss class for auto-encoder modules.
- class ecgan.utils.losses.VariationalAutoEncoderLoss(autoencoder_sampler, use_mse, kl_beta, distribution, device)[source]
Bases:
ecgan.utils.configurable.Configurable
Loss class for variational auto-encoders.
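A variational auto-encoder loss presumably sums a reconstruction term (MSE or BCE, per the use_mse flag) and a beta-weighted KL term against a unit Gaussian. A sketch with an assumed function name and signature:

```python
import torch
import torch.nn.functional as F


def variational_ae_loss(reconstruction, target, mean_value, log_var,
                        kl_beta: float = 1.0, use_mse: bool = True) -> torch.Tensor:
    # Reconstruction term: MSE or binary cross entropy, per the use_mse flag.
    if use_mse:
        rec = F.mse_loss(reconstruction, target)
    else:
        rec = F.binary_cross_entropy(reconstruction, target)
    # KL term against a unit Gaussian, weighted by kl_beta.
    kl = torch.mean(
        -0.5 * torch.sum(1 + log_var - mean_value.pow(2) - log_var.exp(), dim=1)
    )
    return rec + kl_beta * kl
```

With perfect reconstruction and a posterior that already matches the unit Gaussian, both terms vanish and the loss is zero; kl_beta trades reconstruction fidelity against latent regularization.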