Trainer
Basic trainer class used to create data splits, initialize training modules, fit a model and collect metrics.
- class ecgan.training.trainer.Trainer(data, label)
Bases: object
Load all required elements (data, config, tracking, plotter) and initialize the model.
Requires a config that has already been set via ecgan.config.global_cfg.set_global_config().
- get_split()
Retrieve the existing split or create a new split with the given number of folds.
- Return type
Dict
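The retrieve-or-create behavior of get_split can be illustrated with a small stdlib-only sketch. This is not ecgan's implementation. The in-memory `cache` argument is a hypothetical stand-in for however an existing split is actually stored, and the helper `create_folds` with its seeded, round-robin fold assignment is an illustrative choice.

```python
import random
from typing import Dict, List, Optional


def create_folds(num_samples: int, folds: int, seed: int = 0) -> Dict[int, List[int]]:
    """Hypothetical helper: shuffle sample indices and partition them into folds."""
    if folds < 2:
        raise ValueError("folds must be at least 2")
    indices = list(range(num_samples))
    # A seeded RNG keeps the split reproducible across runs.
    random.Random(seed).shuffle(indices)
    # Fold k receives every `folds`-th shuffled index starting at position k,
    # so fold sizes differ by at most one.
    return {k: indices[k::folds] for k in range(folds)}


def get_split(
    cache: Optional[Dict[int, List[int]]], num_samples: int, folds: int
) -> Dict[int, List[int]]:
    """Reuse an existing split with the requested fold count, else create one."""
    if cache is not None and len(cache) == folds:
        return cache
    return create_folds(num_samples, folds)


split = get_split(None, num_samples=10, folds=5)
```

Reusing a stored split, rather than re-sampling each run, keeps fold membership identical across experiments so that metrics from different runs are comparable.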
- static prepare_data(train_x, test_x, vali_x, train_y, test_y, vali_y, trainer_cfg, rescale_to_unit_circle)
Select channels, transform the data, mask the labels and optionally rescale the data, based on the trainer configuration.
- Parameters
train_x -- Train data.
test_x -- Test data.
vali_x -- Validation data.
train_y -- Train labels.
test_y -- Test labels.
vali_y -- Validation labels.
trainer_cfg (TrainerConfig) -- Trainer configuration.
rescale_to_unit_circle (bool) -- Flag indicating whether the data shall be rescaled to the unit circle.
- Return type
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
- Returns
Transformed train, test and vali data and labels.
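The steps of prepare_data can be sketched with the standard library on nested lists shaped (samples, seq_len, channels) instead of Tensors. Several details are assumptions, not documented ecgan behavior: `HypotheticalTrainerConfig` and its `channels` and `normal_label` fields are invented for illustration, label masking is taken to mean collapsing every non-normal class to 1, "rescaled to the unit circle" is interpreted as min-max scaling into [-1, 1] with bounds computed on the training data only, and the return order is assumed to mirror the parameter order.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple

Sample = List[List[float]]  # one sample: seq_len x channels


@dataclass
class HypotheticalTrainerConfig:
    channels: List[int] = field(default_factory=lambda: [0])  # channel indices to keep
    normal_label: int = 0  # class treated as "normal" when masking


def select_channels(data: Sequence[Sample], channels: List[int]) -> List[Sample]:
    # Keep only the configured channels at every time step.
    return [[[step[c] for c in channels] for step in sample] for sample in data]


def mask_labels(labels: Sequence[int], normal_label: int) -> List[int]:
    # Binary masking: 0 for the normal class, 1 for every other class.
    return [0 if y == normal_label else 1 for y in labels]


def rescale(data: Sequence[Sample], lo: float, hi: float) -> List[Sample]:
    # Min-max scale into [-1, 1]; fall back to a span of 1.0 for constant data.
    span = (hi - lo) or 1.0
    return [[[2.0 * (v - lo) / span - 1.0 for v in step] for step in sample] for sample in data]


def prepare_data(
    train_x, test_x, vali_x, train_y, test_y, vali_y,
    cfg: HypotheticalTrainerConfig, rescale_to_unit_circle: bool,
) -> Tuple[list, list, list, list, list, list]:
    xs = [select_channels(x, cfg.channels) for x in (train_x, test_x, vali_x)]
    ys = [mask_labels(y, cfg.normal_label) for y in (train_y, test_y, vali_y)]
    if rescale_to_unit_circle:
        # Bounds come from the training split only.
        values = [v for sample in xs[0] for step in sample for v in step]
        lo, hi = min(values), max(values)
        xs = [rescale(x, lo, hi) for x in xs]
    return xs[0], xs[1], xs[2], ys[0], ys[1], ys[2]
```

Computing the bounds on the training split only avoids leaking test statistics into preprocessing; as a consequence, test or validation values outside the training range map outside [-1, 1].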