NN Classifier
Basic RNN and CNN classifiers that predict labels from input data.
- class ecgan.modules.classifiers.nn_classifier.NNClassifier(cfg, seq_len, num_channels)
Bases: ecgan.modules.base.BaseModule, ecgan.modules.classifiers.base.BaseClassifier
Neural network used to predict labels. It is not used for forecasting, another approach that can also be used for anomaly detection (AD).
- property classifier: torch.nn.modules.module.Module
Return the NN classifier.
- Return type
Module
- property optimizer: ecgan.utils.optimizer.BaseOptimizer
Return the optimizer for the network.
- Return type
BaseOptimizer
- property criterion: ecgan.utils.losses.SupervisedLoss
Return the criterion for the network.
- Return type
SupervisedLoss
- training_step(batch)
Declare what the model should do during a training step using a given batch.
- Parameters
batch (dict) -- The batch of real data.
- Return type
Dict
- Returns
A dict containing the optimization metrics to be logged.
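To illustrate the contract described above (take a batch dict, update the weights, return a dict of metrics to log), here is a minimal sketch using a hypothetical softmax-regression model in plain Python. The class `ToySoftmaxClassifier`, the batch keys `"data"` and `"label"`, the metric name `"loss"`, and the learning rate are all assumptions for illustration; they are not the ecgan implementation.

```python
import math
from typing import Dict, List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


class ToySoftmaxClassifier:
    """Hypothetical linear softmax classifier; NOT part of ecgan."""

    def __init__(self, num_features: int, num_classes: int, lr: float = 0.1) -> None:
        self.weights = [[0.0] * num_features for _ in range(num_classes)]
        self.lr = lr

    def logits(self, x: List[float]) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weights]

    def training_step(self, batch: Dict[str, list]) -> Dict[str, float]:
        # Assumed batch keys: "data" (feature vectors), "label" (class indices).
        data, labels = batch["data"], batch["label"]
        grads = [[0.0] * len(row) for row in self.weights]
        loss = 0.0
        for x, y in zip(data, labels):
            probs = softmax(self.logits(x))
            loss -= math.log(probs[y])  # cross-entropy of the true class
            for k, p in enumerate(probs):
                # d(loss)/d(logit_k) = p_k - 1[k == y]
                coeff = p - (1.0 if k == y else 0.0)
                for j, xi in enumerate(x):
                    grads[k][j] += coeff * xi
        n = len(data)
        # One plain gradient-descent update on the mean loss.
        for row, grad_row in zip(self.weights, grads):
            for j, g in enumerate(grad_row):
                row[j] -= self.lr * g / n
        # Return the (pre-update) mean loss as the metric to be logged.
        return {"loss": loss / n}
```

With all-zero initial weights, the class probabilities are uniform, so the first call on a two-class batch returns a loss of ln 2 (about 0.693). Repeated calls on the same batch should lower it.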
- validation_step(batch)
Declare what the model should do during a validation step.
- Return type
dict
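A validation step typically evaluates the model without updating any weights. The sketch below is a hypothetical stand-alone version: it assumes the same batch keys (`"data"`, `"label"`) and a `score_fn` that returns one score per class, and it reports mean cross-entropy and argmax accuracy. The metric names `val_loss` and `val_accuracy` are illustrative, not the ones ecgan logs.

```python
import math
from typing import Callable, Dict, List


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def validation_step(
    score_fn: Callable[[List[float]], List[float]], batch: Dict[str, list]
) -> Dict[str, float]:
    """Hypothetical validation step: evaluate only, never update weights."""
    data, labels = batch["data"], batch["label"]
    correct = 0
    loss = 0.0
    for x, y in zip(data, labels):
        scores = score_fn(x)
        # Log-sum-exp trick for a numerically stable cross-entropy.
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        loss += log_norm - scores[y]
        # The predicted label is the class with the highest score.
        correct += int(argmax(scores) == y)
    n = len(data)
    return {"val_loss": loss / n, "val_accuracy": correct / n}
```

For example, with `score_fn=lambda x: list(x)` and the batch `{"data": [[2.0, 0.0], [0.0, 3.0], [1.0, 0.5]], "label": [0, 1, 1]}`, the predictions are 0, 1, 0, so the accuracy is 2/3.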
- load(model_reference, load_optim=False)
Load a trained module from disk (file path) or from a Weights & Biases (wandb) reference.
- property watch_list: List[torch.nn.modules.module.Module]
Return models that should be watched during training.
- Return type
List[Module]
- class ecgan.modules.classifiers.nn_classifier.CNNClassifier(cfg, seq_len, num_channels)
Bases: ecgan.modules.classifiers.nn_classifier.NNClassifier
CNN classifier that predicts the label with the highest output score (argmax).
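The idea of an argmax CNN classifier can be sketched in plain Python. This example is hypothetical: it assumes an input of shape (seq_len, num_channels), a single valid 1D convolution, ReLU, global average pooling over time, and a linear output layer without biases. The actual layers, sizes, and weights of CNNClassifier come from its configuration and are not shown here.

```python
from typing import List

Signal = List[List[float]]  # shape: (seq_len, num_channels)


def conv1d(x: Signal, kernels: List[Signal]) -> Signal:
    """Valid 1D convolution (cross-correlation) over time, no bias.

    Each kernel has shape (kernel_size, num_channels); the output has shape
    (seq_len - kernel_size + 1, num_filters).
    """
    k = len(kernels[0])
    out = []
    for t in range(len(x) - k + 1):
        out.append([
            sum(w * v for dt in range(k) for w, v in zip(kern[dt], x[t + dt]))
            for kern in kernels
        ])
    return out


def argmax_cnn_predict(x: Signal, kernels: List[Signal], dense: List[List[float]]) -> int:
    """Hypothetical CNN forward pass followed by argmax over class scores."""
    feats = conv1d(x, kernels)
    feats = [[max(0.0, v) for v in row] for row in feats]  # ReLU
    # Global average pooling over time: one value per filter.
    pooled = [sum(col) / len(feats) for col in zip(*feats)]
    # Linear layer: one row of weights per class.
    logits = [sum(w * p for w, p in zip(row, pooled)) for row in dense]
    return max(range(len(logits)), key=logits.__getitem__)
```

With a "rising" kernel `[[-1.0], [1.0]]`, a "falling" kernel `[[1.0], [-1.0]]`, and an identity output layer, an increasing single-channel series is assigned class 0 and a decreasing one class 1.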
- class ecgan.modules.classifiers.nn_classifier.RNNClassifier(cfg, seq_len, num_channels)
Bases: ecgan.modules.classifiers.nn_classifier.NNClassifier
RNN classifier that predicts the label with the highest output score (argmax).
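Similarly, an argmax RNN classifier can be sketched as a simple Elman recurrence, h_t = tanh(W_in x_t + W_rec h_{t-1}), whose final hidden state is mapped to class scores before taking the argmax. This is a hypothetical illustration. Biases are omitted, the weight names are made up, and the recurrent cell RNNClassifier actually uses (for example an LSTM) is not specified here.

```python
import math
from typing import List

Matrix = List[List[float]]


def matvec(m: Matrix, v: List[float]) -> List[float]:
    """Matrix-vector product for lists of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def argmax_rnn_predict(x: Matrix, w_in: Matrix, w_rec: Matrix, w_out: Matrix) -> int:
    """Hypothetical Elman RNN over x of shape (seq_len, num_channels)."""
    hidden = [0.0] * len(w_rec)  # initial hidden state h_0 = 0
    for step in x:
        pre = [a + b for a, b in zip(matvec(w_in, step), matvec(w_rec, hidden))]
        hidden = [math.tanh(p) for p in pre]
    # Class scores come from the last hidden state only.
    logits = matvec(w_out, hidden)
    return max(range(len(logits)), key=logits.__getitem__)
```

With one channel, one hidden unit, `w_in=[[1.0]]`, `w_rec=[[0.5]]`, and `w_out=[[1.0], [-1.0]]`, a sequence of positive values ends in a positive hidden state and is assigned class 0, while a sequence of negative values is assigned class 1.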