NN Classifier

Basic RNN and CNN classifiers that predict labels from input data.

class ecgan.modules.classifiers.nn_classifier.NNClassifier(cfg, seq_len, num_channels)

Bases: ecgan.modules.base.BaseModule, ecgan.modules.classifiers.base.BaseClassifier

Neural network used to predict labels. It is not used for forecasting, even though forecasting models can also be used for anomaly detection (AD).

property classifier: torch.nn.modules.module.Module

Return the NN classifier.

Return type

Module

classify(data)

Return a classification score according to the NN.

Return type

Tensor
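The sketch below is for illustration only. It assumes the network outputs one raw score (logit) per class for each sample. A softmax turns those scores into probability-like values, and an argmax classifier picks the class with the highest score. Plain Python lists stand in for tensors. `softmax` and `predict_labels` are hypothetical helpers, not part of ECGAN, and this reference does not say whether `classify` returns logits, probabilities or labels.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw class scores (logits) into probabilities that sum to 1."""
    # Subtract the maximum logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(batch_logits: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the highest-scoring class for each sample (argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


# Hypothetical network output: two samples, three classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
print(softmax(logits[0]))
print(predict_labels(logits))  # [0, 2]
```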

property optimizer: ecgan.utils.optimizer.BaseOptimizer

Return the optimizer for the network.

Return type

BaseOptimizer

property criterion: ecgan.utils.losses.SupervisedLoss

Return the criterion for the network.

Return type

SupervisedLoss

training_step(batch)

Declare what the model should do during a training step using a given batch.

Parameters

batch (dict) -- The batch of real data.

Return type

Dict

Returns

A dict containing the optimization metrics which shall be logged.
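Here is a minimal sketch of the `training_step` contract under several assumptions:

- A binary logistic-regression model stands in for the neural network.
- The batch is a dict with the hypothetical keys `"data"` and `"label"`.
- Binary cross-entropy stands in for the configured `SupervisedLoss`.

The function performs one gradient-descent update and returns the metrics to log.

```python
import math
from typing import Dict, List, Tuple


def training_step(
    weights: List[float], bias: float, batch: Dict[str, list], lr: float = 0.1
) -> Tuple[List[float], float, Dict[str, float]]:
    """One gradient-descent step of a binary logistic-regression classifier.

    Hypothetical stand-in for a neural network: returns the updated
    parameters plus a dict of metrics that should be logged.
    """
    data, labels = batch["data"], batch["label"]
    n = len(data)
    grad_w = [0.0] * len(weights)
    grad_b = 0.0
    loss = 0.0
    correct = 0
    for x, y in zip(data, labels):
        z = sum(w * xi for w, xi in zip(weights, x)) + bias
        p = 1.0 / (1.0 + math.exp(-z))  # predicted probability of class 1
        # Binary cross-entropy; clamp p to avoid log(0).
        p_c = min(max(p, 1e-12), 1 - 1e-12)
        loss -= y * math.log(p_c) + (1 - y) * math.log(1 - p_c)
        correct += int((p >= 0.5) == bool(y))
        # The gradient of BCE with respect to the logit is (p - y).
        for i, xi in enumerate(x):
            grad_w[i] += (p - y) * xi
        grad_b += p - y
    new_weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
    new_bias = bias - lr * grad_b / n
    # Metrics describe the model *before* this update.
    metrics = {"loss": loss / n, "accuracy": correct / n}
    return new_weights, new_bias, metrics


batch = {"data": [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], "label": [1, 0, 1, 0]}
w, b = [0.0, 0.0], 0.0
for _ in range(3):
    w, b, metrics = training_step(w, b, batch)
    print(metrics)
```

The reported loss is computed before the parameter update. A typical PyTorch loop reports it the same way: forward pass, backward pass, then `optimizer.step()`.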

validation_step(batch)

Declare what the model should do during a validation step.

Return type

dict
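A validation step only measures performance and leaves the parameters untouched. The hypothetical sketch below computes accuracy and macro-averaged F1 from predicted and true labels. The metric names are assumptions, not the keys ECGAN actually logs.

```python
from typing import Dict, Sequence


def validation_step(predictions: Sequence[int], labels: Sequence[int]) -> Dict[str, float]:
    """Compute validation metrics from predicted and true labels.

    No parameters are updated: validation only measures how well the
    current model generalizes to held-out data.
    """
    n = len(labels)
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / n
    f1_scores = []
    for c in sorted(set(labels) | set(predictions)):
        tp = sum(p == c and y == c for p, y in zip(predictions, labels))
        fp = sum(p == c and y != c for p, y in zip(predictions, labels))
        fn = sum(p != c and y == c for p, y in zip(predictions, labels))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    # Macro average: every class counts equally, regardless of its frequency.
    return {"val_accuracy": accuracy, "val_macro_f1": sum(f1_scores) / len(f1_scores)}


print(validation_step([0, 1, 1, 0], [0, 1, 0, 0]))
```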

save_checkpoint()

Return the current model parameters.

Return type

dict
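This sketch shows what a checkpoint dict could hold. It assumes the dict bundles model and optimizer state under the hypothetical keys `"model"` and `"optimizer"`. In PyTorch, these would typically come from `Module.state_dict()` and `Optimizer.state_dict()`. `TinyModel` is a hypothetical stand-in.

The deep copy keeps the snapshot separate from parameters that keep changing during training. This matters because PyTorch state dicts reference the live parameter tensors rather than copying them.

```python
import copy
from typing import Any, Dict


class TinyModel:
    """Hypothetical stand-in for a trainable module with optimizer state."""

    def __init__(self) -> None:
        self.params = {"weight": [0.5, -0.25], "bias": 0.1}
        self.optim_state = {"step": 0, "lr": 1e-3}

    def save_checkpoint(self) -> Dict[str, Any]:
        # Deep-copy so later training updates do not alter the saved snapshot.
        return {
            "model": copy.deepcopy(self.params),
            "optimizer": copy.deepcopy(self.optim_state),
        }


model = TinyModel()
ckpt = model.save_checkpoint()
model.params["weight"][0] = 9.9  # training continues and changes the weights
print(ckpt["model"]["weight"])  # [0.5, -0.25]
```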

load(model_reference, load_optim=False)

Load a trained module from disk (file path) or from a Weights & Biases (wandb) reference.
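Here is a hedged sketch of the `load_optim` flag for the file-path case. It assumes the checkpoint is a dict with hypothetical `"model"` and `"optimizer"` entries. `pickle` stands in for the real serialization format, which this reference does not specify. W&B references are not handled.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def load(model_reference: str, load_optim: bool = False) -> Dict[str, Any]:
    """Load a checkpoint from a local file path (hypothetical sketch)."""
    with open(model_reference, "rb") as f:
        checkpoint = pickle.load(f)
    restored = {"model": checkpoint["model"]}
    if load_optim:
        # Optimizer state is only needed to resume training, not for inference.
        restored["optimizer"] = checkpoint["optimizer"]
    return restored


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "classifier.pkl")
    with open(path, "wb") as f:
        pickle.dump({"model": {"bias": 0.1}, "optimizer": {"step": 42}}, f)
    print(load(path))                   # {'model': {'bias': 0.1}}
    print(load(path, load_optim=True))  # also includes the optimizer state
```

As with any pickle-based format, only load files from trusted sources.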

property watch_list: List[torch.nn.modules.module.Module]

Return models that should be watched during training.

Return type

List[Module]

on_epoch_end(epoch, sample_interval, batch_size)

Set actions to be executed after an epoch ends.

Declare what should be done upon finishing an epoch (e.g. save artifacts or evaluate some metric).

Return type

List[Artifact]
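The sketch below rests on one hypothetical assumption: `sample_interval` counts epochs, so artifacts are produced only every `sample_interval` epochs. This keeps expensive evaluation from running after every epoch. The sketch returns artifact file names as strings rather than ECGAN `Artifact` objects, and the confusion-matrix name is purely illustrative.

```python
from typing import List


def on_epoch_end(epoch: int, sample_interval: int, batch_size: int) -> List[str]:
    """Return names of artifacts to create after ``epoch`` finishes (hypothetical)."""
    artifacts: List[str] = []
    if sample_interval > 0 and epoch % sample_interval == 0:
        # e.g. a confusion-matrix plot computed on ``batch_size`` samples.
        artifacts.append(f"confusion_matrix_epoch{epoch}_n{batch_size}.png")
    return artifacts


for epoch in range(1, 7):
    # Artifacts are produced only at epochs 3 and 6.
    print(epoch, on_epoch_end(epoch, sample_interval=3, batch_size=64))
```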

class ecgan.modules.classifiers.nn_classifier.CNNClassifier(cfg, seq_len, num_channels)

Bases: ecgan.modules.classifiers.nn_classifier.NNClassifier

Argmax CNN classifier.
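This pure-Python sketch illustrates the usual pipeline behind an argmax CNN classifier:

1. A 1D convolution over a single-channel signal. This is a cross-correlation, as in PyTorch's `Conv1d`.
2. ReLU.
3. Global average pooling.
4. A linear layer.
5. An argmax over the class scores.

The kernels and weights are hand-picked for the example. The layer layout is an assumption, not ECGAN's actual architecture, and only one input channel is handled.

```python
from typing import List, Sequence


def conv1d(signal: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """Valid (no padding) 1D cross-correlation, as computed by CNN layers."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k)) for i in range(len(signal) - k + 1)]


def relu(xs: Sequence[float]) -> List[float]:
    return [max(0.0, x) for x in xs]


def cnn_predict(
    signal: Sequence[float],
    kernels: Sequence[Sequence[float]],
    class_weights: Sequence[Sequence[float]],
) -> int:
    """Conv -> ReLU -> global average pooling -> linear layer -> argmax."""
    # One feature per kernel: its average activation over time.
    features = []
    for kernel in kernels:
        act = relu(conv1d(signal, kernel))
        features.append(sum(act) / len(act))
    logits = [sum(w * f for w, f in zip(row, features)) for row in class_weights]
    return max(range(len(logits)), key=logits.__getitem__)


# Hand-picked weights: kernel 0 responds to rising edges, kernel 1 to falling edges.
kernels = [[-1.0, 1.0], [1.0, -1.0]]
class_weights = [[1.0, 0.0], [0.0, 1.0]]  # class 0 = "rising", class 1 = "falling"
print(cnn_predict([0.0, 1.0, 2.0, 3.0], kernels, class_weights))  # 0
print(cnn_predict([3.0, 2.0, 1.0, 0.0], kernels, class_weights))  # 1
```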

static configure()

Return the default configuration of a standard CNN classifier.

Return type

Dict
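`configure()` returns a default configuration dict. The hypothetical sketch below shows a common pattern for such defaults: a nested dict into which user settings are merged recursively. All keys and values (`MODULE`, `OPTIMIZER`, `LR`, ...) are illustrative placeholders, not ECGAN's actual defaults.

```python
import copy
from typing import Any, Dict


def configure() -> Dict[str, Any]:
    """Hypothetical default configuration; keys and values are placeholders."""
    return {
        "MODULE": {
            "NAME": "cnn_classifier",
            "OPTIMIZER": {"NAME": "adam", "LR": 1e-3},
            "LOSS": "cross_entropy",
        }
    }


def merge(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overlay user overrides onto the defaults without mutating either."""
    result = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = merge(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


cfg = merge(configure(), {"MODULE": {"OPTIMIZER": {"LR": 5e-4}}})
print(cfg["MODULE"]["OPTIMIZER"])  # {'NAME': 'adam', 'LR': 0.0005}
```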

class ecgan.modules.classifiers.nn_classifier.RNNClassifier(cfg, seq_len, num_channels)

Bases: ecgan.modules.classifiers.nn_classifier.NNClassifier

Argmax RNN classifier.
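An argmax RNN classifier reads the sequence one step at a time and classifies from its final hidden state. The sketch below uses a plain Elman cell, h_t = tanh(w_in * x_t + W_rec h_{t-1}), in pure Python. This reference does not say whether ECGAN uses a vanilla RNN, LSTM or GRU cell, and all weights here are hypothetical.

```python
import math
from typing import List, Sequence


def rnn_predict(
    sequence: Sequence[float],
    w_in: Sequence[float],
    w_rec: Sequence[Sequence[float]],
    w_out: Sequence[Sequence[float]],
) -> int:
    """Elman RNN over a univariate sequence; classify from the final hidden state."""
    h: List[float] = [0.0] * len(w_in)  # initial hidden state
    for x in sequence:
        # h_t = tanh(w_in * x_t + W_rec @ h_{t-1})
        h = [
            math.tanh(w_in[i] * x + sum(w_rec[i][j] * h[j] for j in range(len(h))))
            for i in range(len(h))
        ]
    logits = [sum(w * hi for w, hi in zip(row, h)) for row in w_out]
    return max(range(len(logits)), key=logits.__getitem__)


# Hidden size 1: positive sequences drive h above 0 (class 1), negative ones below 0 (class 0).
w_in, w_rec, w_out = [1.0], [[0.5]], [[-1.0], [1.0]]
print(rnn_predict([1.0, 1.0, 1.0], w_in, w_rec, w_out))     # 1
print(rnn_predict([-1.0, -1.0, -1.0], w_in, w_rec, w_out))  # 0
```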

static configure()

Return the default configuration of a standard RNN classifier.

Return type

Dict