Trainer

Basic trainer class used to create data splits, initialize training modules, fit a model and collect metrics.

class ecgan.training.trainer.Trainer(data, label)

Bases: object

Load all required elements (data, config, tracking, plotter) and initialize the model.

Requires a previously set config (via ecgan.config.global_cfg.set_global_config()).

fit()

Fit a model on a split.

Return type

BaseModule

get_split()

Retrieve an existing split or create a new split with the given number of folds.

Return type

Dict
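
The "retrieve or create" behaviour of get_split() can be illustrated with a minimal, library-independent sketch. Everything below is an assumption made for illustration: the function name get_or_create_split, its parameters, and the layout of the returned dictionary are hypothetical and are not taken from ecgan.

```python
import random
from typing import Dict, List, Optional

Split = Dict[int, Dict[str, List[int]]]


def get_or_create_split(
    n_samples: int,
    n_folds: int,
    existing: Optional[Split] = None,
    seed: int = 0,
) -> Split:
    """Hypothetical helper: reuse `existing` or build a shuffled k-fold index split."""
    if existing is not None:
        # An existing split is returned unchanged so that runs stay comparable.
        return existing
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Distribute indices round-robin so fold sizes differ by at most one.
    folds = [indices[i::n_folds] for i in range(n_folds)]
    split: Split = {}
    for k, test_idx in enumerate(folds):
        train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
        split[k] = {"train": sorted(train_idx), "test": sorted(test_idx)}
    return split
```

In this sketch, every sample index appears in exactly one test fold, and passing a previously created dictionary as existing returns it as-is instead of reshuffling.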

static prepare_data(train_x, test_x, vali_x, train_y, test_y, vali_y, trainer_cfg, rescale_to_unit_circle)

Select channels, transform the data, mask the labels and optionally rescale the data, as specified by the trainer configuration.

Parameters
  • train_x -- Train data.

  • test_x -- Test data.

  • vali_x -- Validation data.

  • train_y -- Train labels.

  • test_y -- Test labels.

  • vali_y -- Validation labels.

  • trainer_cfg (TrainerConfig) -- Trainer configuration.

  • rescale_to_unit_circle (bool) -- Flag indicating whether the data should be rescaled to the unit circle.

Return type

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Returns

Transformed train, test and validation data and labels.
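
The steps that prepare_data() describes can be sketched in plain Python as follows. This is not the ecgan implementation; it rests on several assumptions: data is represented as nested lists of shape (samples, time steps, channels) instead of tensors, "masking labels" is taken to mean collapsing all non-normal classes to 1, "rescaling to the unit circle" is interpreted as min-max scaling to [-1, 1] with statistics from the training split, and the return order mirrors the argument order. All function names and the channels and rescale_flag parameters are hypothetical.

```python
from typing import List, Sequence, Tuple

Series = List[List[float]]  # one sample: seq_len rows, each with n_channels values


def select_channels(data: List[Series], channels: Sequence[int]) -> List[Series]:
    """Keep only the requested channel columns of every time step."""
    return [[[step[c] for c in channels] for step in sample] for sample in data]


def mask_labels(labels: Sequence[int], normal_class: int = 0) -> List[int]:
    """Assumed masking: 0 for the normal class, 1 for every other class."""
    return [0 if y == normal_class else 1 for y in labels]


def fit_min_max(data: List[Series]) -> Tuple[float, float]:
    """Global minimum and maximum over all samples, steps and channels."""
    values = [v for sample in data for step in sample for v in step]
    return min(values), max(values)


def rescale(data: List[Series], lo: float, hi: float) -> List[Series]:
    """Map [lo, hi] linearly onto [-1, 1]; a constant input maps to 0."""
    span = hi - lo
    if span == 0:
        return [[[0.0 for _ in step] for step in sample] for sample in data]
    return [[[2 * (v - lo) / span - 1 for v in step] for step in sample] for sample in data]


def prepare_data_sketch(train_x, test_x, vali_x, train_y, test_y, vali_y,
                        channels, rescale_flag):
    """Hypothetical stand-in for prepare_data: select, mask, optionally rescale."""
    xs = [select_channels(x, channels) for x in (train_x, test_x, vali_x)]
    ys = [mask_labels(y) for y in (train_y, test_y, vali_y)]
    if rescale_flag:
        # Statistics come from the training split only, to avoid leaking
        # information from the test and validation data.
        lo, hi = fit_min_max(xs[0])
        xs = [rescale(x, lo, hi) for x in xs]
    return xs[0], xs[1], xs[2], ys[0], ys[1], ys[2]
```

Because the scaling parameters are fitted on the training split alone, test or validation values can fall outside [-1, 1] in this sketch; whether ecgan clips or handles such values differently is not stated in the documentation above.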