Source code for ecgan.networks.beatgan

"""BeatGAN encoder, generator and discriminator from Zhou et al. 2019."""
from typing import Dict, List

from torch import nn

from ecgan.networks.helpers import create_5_hidden_layer_convnet, create_transpose_conv_net
from ecgan.utils.configurable import ConfigurableTorchModule
from ecgan.utils.custom_types import InputNormalization, WeightInitialization
from ecgan.utils.losses import AEGANDiscriminatorLoss, BceGeneratorLoss, L2Loss
from ecgan.utils.optimizer import Adam


class BeatganInverseEncoder(ConfigurableTorchModule):
    """Encoder of the BeatGAN model."""

    def __init__(
        self,
        input_channels: int,
        hidden_channels: List[int],
        output_channels: int,
        seq_len: int,
        input_norm: InputNormalization,
        spectral_norm: bool,
    ):
        super().__init__()
        self.net = create_5_hidden_layer_convnet(
            input_channels,
            hidden_channels,
            output_channels,
            seq_len,
            input_norm=input_norm,
            spectral_norm=spectral_norm,
            track_running_stats=True,
        )
    def forward(self, x):
        """Perform a forward pass."""
        x = x.permute(0, 2, 1)  # (batch, seq_len, channels) -> (batch, channels, seq_len)
        x = self.net(x)
        x = x.permute(0, 2, 1)  # back to channels-last
        return x
    @staticmethod
    def configure() -> Dict:
        """Return the default configuration for the encoder of the BeatGAN module."""
        config = {
            'ENCODER': {
                'LAYER_SPECIFICATION': {
                    'HIDDEN_CHANNELS': [32, 64, 128, 256, 512],
                },
                'INPUT_NORMALIZATION': InputNormalization.BATCH.value,
                'SPECTRAL_NORM': False,
                'WEIGHT_INIT': {
                    'NAME': WeightInitialization.GLOROT_NORMAL.value,
                },
            }
        }
        config['ENCODER'].update(L2Loss.configure())
        config['ENCODER'].update(Adam.configure())
        return config
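
For orientation, a minimal usage sketch (not part of the module): instantiating the encoder with illustrative values and pushing a batch of beats through it. The batch size of 8, sequence length of 320 and latent width of 50 are assumptions chosen for the example, not values fixed by this file.

    import torch

    encoder = BeatganInverseEncoder(
        input_channels=1,                         # univariate beats (assumption)
        hidden_channels=[32, 64, 128, 256, 512],  # defaults from configure()
        output_channels=50,                       # latent width (assumption)
        seq_len=320,                              # window length (assumption)
        input_norm=InputNormalization.BATCH,
        spectral_norm=False,
    )
    x = torch.randn(8, 320, 1)  # (batch, seq_len, channels); permuted to channels-first internally
    z = encoder(x)              # latent code; exact shape depends on the conv helper's downsampling
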

class BeatganDiscriminator(ConfigurableTorchModule):
    """Discriminator of the BeatGAN model."""

    def __init__(
        self,
        input_channels: int,
        hidden_channels: List[int],
        output_channels: int,
        seq_len: int,
        input_norm: InputNormalization,
        spectral_norm: bool,
    ):
        super().__init__()
        model: nn.Module = create_5_hidden_layer_convnet(
            input_channels=input_channels,
            hidden_channels=hidden_channels,
            output_channels=output_channels,
            seq_len=seq_len,
            input_norm=input_norm,
            spectral_norm=spectral_norm,
            track_running_stats=True,
        )
        layers = list(model.children())
        self.features = nn.Sequential(*layers[:-1])
        self.classifier = nn.Sequential(layers[-1])
        self.classifier.add_module('Sigmoid', nn.Sigmoid())
    def forward(self, x):
        """Perform a forward pass."""
        x = x.permute(0, 2, 1)  # (batch, seq_len, channels) -> (batch, channels, seq_len)
        features = self.features(x)
        classifier = self.classifier(features).view(-1, 1).squeeze(1)
        features = features.permute(0, 2, 1)  # return features channels-last as well
        return classifier, features
    @staticmethod
    def configure() -> Dict:
        """Return the default configuration for the discriminator of the BeatGAN model."""
        config = {
            'DISCRIMINATOR': {
                'SPECTRAL_NORM': True,
                'LAYER_SPECIFICATION': {
                    'HIDDEN_CHANNELS': [32, 64, 128, 256, 512],
                },
                'INPUT_NORMALIZATION': InputNormalization.NONE.value,
                'WEIGHT_INIT': {
                    'NAME': WeightInitialization.GLOROT_NORMAL.value,
                },
            }
        }
        config['DISCRIMINATOR'].update(AEGANDiscriminatorLoss.configure())
        config['DISCRIMINATOR'].update(Adam.configure())
        config['DISCRIMINATOR']['OPTIMIZER']['BETAS'] = [0.5, 0.999]  # type: ignore
        return config
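
The discriminator's forward pass returns both the sigmoid score and the penultimate feature map, which feature-matching terms in AE-GAN-style losses typically compare between real and reconstructed beats. A hedged calling sketch, reusing the illustrative shapes from above; that the score comes out as one value per sample assumes the conv helper reduces the temporal dimension to one:

    import torch

    disc = BeatganDiscriminator(
        input_channels=1,
        hidden_channels=[32, 64, 128, 256, 512],
        output_channels=1,                        # single real/fake output channel (assumption)
        seq_len=320,                              # assumption, as above
        input_norm=InputNormalization.NONE,
        spectral_norm=True,
    )
    score, features = disc(torch.randn(8, 320, 1))
    # score: per-sample probability in (0, 1); features: intermediate feature map
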

class BeatganGenerator(ConfigurableTorchModule):
    """Generator of the BeatGAN model."""

    def __init__(
        self,
        input_channels: int,
        hidden_channels: List[int],
        latent_size: int,
        seq_len: int,
        input_norm: InputNormalization,
        spectral_norm: bool,
        tanh_out: bool,
    ):
        super().__init__()
        self.model = create_transpose_conv_net(
            input_channels=latent_size,
            hidden_channels=hidden_channels,
            output_channels=input_channels,
            seq_len=seq_len,
            input_norm=input_norm,
            spectral_norm=spectral_norm,
            track_running_stats=True,
        )
        if tanh_out:
            self.model.add_module('Tanh', nn.Tanh())
        else:
            self.model.add_module('Sigmoid', nn.Sigmoid())
    def forward(self, x):
        """Perform a forward pass."""
        x = x.permute(0, 2, 1)  # (batch, length, latent_size) -> (batch, latent_size, length)
        x = self.model(x)
        x = x.permute(0, 2, 1)  # back to channels-last
        return x
    @staticmethod
    def configure() -> Dict:
        """Return the default configuration for the generator of the BeatGAN module."""
        config = {
            'GENERATOR': {
                'LAYER_SPECIFICATION': {
                    'HIDDEN_CHANNELS': [512, 256, 128, 64, 32],
                },
                'TANH_OUT': True,
                'INPUT_NORMALIZATION': InputNormalization.BATCH.value,
                'SPECTRAL_NORM': False,
                'WEIGHT_INIT': {
                    'NAME': WeightInitialization.GLOROT_NORMAL.value,
                },
            }
        }
        config['GENERATOR'].update(BceGeneratorLoss.configure())
        config['GENERATOR'].update(Adam.configure())
        config['GENERATOR']['OPTIMIZER']['BETAS'] = [0.5, 0.999]  # type: ignore
        return config
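
A matching sketch for the generator: the latent code enters in (batch, length, latent_size) layout, is permuted to channels-first internally, and the transpose convolutions upsample it back to seq_len steps. The concrete sizes are again illustrative assumptions.

    import torch

    gen = BeatganGenerator(
        input_channels=1,                         # channels of the generated series
        hidden_channels=[512, 256, 128, 64, 32],  # defaults from configure()
        latent_size=50,                           # should match the encoder's latent width (assumption)
        seq_len=320,                              # assumption, as above
        input_norm=InputNormalization.BATCH,
        spectral_norm=False,
        tanh_out=True,
    )
    z = torch.randn(8, 1, 50)  # permuted internally to 50 input channels of length 1
    x_hat = gen(z)             # generated beats, squashed into [-1, 1] by the Tanh output
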