Base Module

Abstract base module for learning algorithms which use training and validation steps.

class ecgan.modules.base.BaseModule(cfg, seq_len, num_channels)[source]

Bases: ecgan.utils.configurable.Configurable

Base class from which all implemented modules should inherit.

abstract training_step(batch)[source]

Declare what the model should do during a training step using a given batch.

The returned metrics are concatenated across batches and averaged before logging. Note that this affects min/max values: the logged number is the mean of the per-batch values, not the minimum or maximum over the whole epoch.

Parameters

batch (dict) -- A batch of data.

Return type

dict

Returns

A dict containing the metrics from optimization or evaluation which shall be logged.
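
The effect of averaging per-batch metrics can be illustrated with a minimal sketch. This is not ecgan's logging code: `aggregate_metrics` is a hypothetical helper, the metric names are made up, and each metric is assumed to be a plain float.

```python
from statistics import mean


def aggregate_metrics(batch_metrics):
    """Average each metric over all batches (hypothetical helper, not part of ecgan)."""
    keys = batch_metrics[0].keys()
    return {key: mean(m[key] for m in batch_metrics) for key in keys}


# Per-batch dicts as a training_step implementation might return them.
per_batch = [
    {"loss": 0.9, "max_grad": 2.0},
    {"loss": 0.7, "max_grad": 6.0},
    {"loss": 0.5, "max_grad": 1.0},
]

averaged = aggregate_metrics(per_batch)
# The true maximum over the epoch has to be tracked separately.
true_max = max(m["max_grad"] for m in per_batch)
```

Here `averaged["max_grad"]` is 3.0, the mean of the per-batch maxima, while `true_max` is 6.0. A module that needs the epoch-wide extreme would have to track it outside the averaged dict.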

abstract validation_step(batch)[source]

Declare what the model should do during a validation step.

Parameters

batch (dict) -- A batch of data.

Return type

dict

Returns

A dict containing the metrics from optimization or evaluation which shall be logged.

abstract save_checkpoint()[source]

Return current model parameters.

Return type

dict

abstract load(model_reference, load_optim=False)[source]

Load a trained module from disk (file path) or from a wandb reference.

Return type

BaseModule

abstract property watch_list: List

Return torch nn.Modules that should be watched during training.

Return type

List

classmethod print_metric(epoch, metrics)[source]

Allow overriding the static _print_metric method.

Return type

None
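
One way such a hook could work is sketched below. The classes `BaseSketch` and `QuietModule` are hypothetical stand-ins, and the delegation from `print_metric` to `_print_metric` via `cls` is an assumption about the design, not a copy of the ecgan implementation.

```python
class BaseSketch:
    """Stand-in for BaseModule; only the printing hooks are modelled."""

    @staticmethod
    def _print_metric(epoch, metrics):
        # Default formatting: one line per epoch with all metrics.
        formatted = ", ".join(f"{name}={value:.4f}" for name, value in metrics.items())
        print(f"epoch {epoch}: {formatted}")

    @classmethod
    def print_metric(cls, epoch, metrics):
        # Resolving _print_metric through cls lets subclasses supply their own version.
        cls._print_metric(epoch, metrics)


class QuietModule(BaseSketch):
    """Hypothetical subclass that replaces the default formatting."""

    @staticmethod
    def _print_metric(epoch, metrics):
        # Report only the loss.
        print(f"[{epoch}] loss={metrics['loss']:.2f}")
```

Calling `QuietModule.print_metric(3, {"loss": 0.5})` uses the subclass formatter, while `BaseSketch.print_metric` keeps the default.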

abstract on_epoch_end(epoch, sample_interval, batch_size)[source]

Set actions to be executed after epoch ends.

Declare what should be done upon finishing an epoch (e.g. save artifacts or evaluate some metric).

Parameters
  • epoch (int) -- Current training epoch.

  • sample_interval (int) -- Regular sampling interval to save modules independent of their performance.

  • batch_size (int) -- Size of batch.

Return type

List[Artifact]

Returns

List containing all ecgan.utils.artifacts.Artifact objects which shall be logged upon epoch end.
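
A minimal sketch of an epoch-end hook that follows the documented signature is given below. `ArtifactSketch` is a hypothetical stand-in for ecgan.utils.artifacts.Artifact, whose real fields may differ, and the modulo check is only one plausible reading of how `sample_interval` is used.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ArtifactSketch:
    """Stand-in for ecgan.utils.artifacts.Artifact; the real class may look different."""

    name: str
    data: Any


def on_epoch_end_sketch(epoch: int, sample_interval: int, batch_size: int) -> List[ArtifactSketch]:
    """Return artifacts every `sample_interval` epochs, independent of performance."""
    artifacts: List[ArtifactSketch] = []
    if sample_interval > 0 and epoch % sample_interval == 0:
        # Placeholder payload; a real module might save weights or generated samples here.
        payload = {"epoch": epoch, "batch_size": batch_size}
        artifacts.append(ArtifactSketch(name=f"epoch_{epoch}", data=payload))
    return artifacts
```

With `sample_interval=2`, the sketch returns one artifact for even epochs and an empty list otherwise.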