Base classifier

Definition of a base PyTorch classifier.

class ecgan.modules.classifiers.base.BaseClassifier

Bases: abc.ABC

Abstract base class for classification models.

Each subclass must implement classify(), which scores incoming data.

abstract classify(data)

Return a classification score.

Return type

Tensor
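A minimal sketch of how a concrete classifier might satisfy this interface. It assumes nothing about ecgan's real subclasses: BaseClassifier below is a local stand-in that mirrors the documented abstract method so the snippet runs without ecgan installed, and LinearClassifier, num_features and num_classes are hypothetical. Returning softmax probabilities is one possible choice of "classification score", not necessarily the one ecgan uses.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class BaseClassifier(ABC):
    """Local stand-in mirroring the documented interface (assumption)."""

    @abstractmethod
    def classify(self, data) -> Tensor:
        """Return a classification score."""
        raise NotImplementedError


class LinearClassifier(BaseClassifier):
    """Hypothetical subclass: one linear layer followed by softmax."""

    def __init__(self, num_features: int, num_classes: int) -> None:
        self.model = nn.Linear(num_features, num_classes)

    def classify(self, data: Tensor) -> Tensor:
        # Scoring queries is inference only, so gradients are not tracked.
        with torch.no_grad():
            logits = self.model(data)
        # Per-class probabilities; each row sums to 1.
        return torch.softmax(logits, dim=1)


torch.manual_seed(0)
clf = LinearClassifier(num_features=4, num_classes=3)
scores = clf.classify(torch.randn(8, 4))  # shape (8, 3)
predicted = scores.argmax(dim=1)          # hard labels, shape (8,)
```

Because classify is abstract, BaseClassifier itself cannot be instantiated; any subclass that omits classify raises a TypeError on construction.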

static get_classification_metrics(real_label, prediction_labels, stage='metrics/', get_fscore_weighted=False, get_fscore_micro=False, get_prec_recall_fscore=True, get_accuracy=True, get_mcc=True, get_auroc=True)

Compute classification metrics for the given ground-truth and predicted labels.

Parameters
  • real_label (ndarray) -- Ground-truth labels (y).

  • prediction_labels (ndarray) -- Predicted labels (y_hat).

  • stage (str) -- String identifier for the logging stage.

  • get_prec_recall_fscore (bool) -- Flag to indicate if precision, recall, F-score (macro) and support should be computed.

  • get_fscore_weighted (bool) -- Flag to indicate if the weighted F-score should be computed.

  • get_fscore_micro (bool) -- Flag to indicate if the micro F-score should be computed.

  • get_accuracy (bool) -- Flag indicating if the accuracy should be computed.

  • get_mcc (bool) -- Flag indicating if the Matthews correlation coefficient (MCC) should be computed.

  • get_auroc (bool) -- Flag indicating if the area under the ROC curve (AUROC) should be computed.
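For reference, the standard definitions behind the accuracy, F-score and MCC flags are given below, assuming the method follows the usual conventions (this page does not confirm the exact implementation). Here N is the number of samples, K the number of classes, P_k, R_k and n_k the precision, recall and support of class k, and TP, TN, FP, FN the binary confusion-matrix counts.

```latex
\begin{align}
\mathrm{Acc} &= \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i = y_i\right] \\
F_k &= \frac{2\,P_k R_k}{P_k + R_k} \\
F_{\text{macro}} &= \frac{1}{K}\sum_{k=1}^{K} F_k \qquad
F_{\text{weighted}} = \sum_{k=1}^{K} \frac{n_k}{N}\,F_k \\
\mathrm{MCC} &= \frac{TP\cdot TN - FP\cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}
\end{align}
```

The micro F-score instead pools TP, FP and FN over all classes before computing a single F-score; for single-label multiclass data it equals accuracy.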

Return type

Dict

Returns

Dict containing every metric whose flag is set to True.
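The flag-gated pattern can be sketched in plain NumPy for binary 0/1 labels. This is a hypothetical illustration, not ecgan's implementation: the function name, the get_fscore_macro flag and the key names are invented here, and prefixing keys with stage is an assumption about how the logging identifier is used. The sketch omits the precision/recall/support, weighted/micro F-score and AUROC options of the real method.

```python
from typing import Dict

import numpy as np


def classification_metrics_sketch(
    real_label: np.ndarray,
    prediction_labels: np.ndarray,
    stage: str = "metrics/",
    get_accuracy: bool = True,
    get_fscore_macro: bool = True,
    get_mcc: bool = True,
) -> Dict[str, float]:
    """Hypothetical flag-gated metrics for 0/1 labels (not ecgan's code)."""
    y = np.asarray(real_label)
    y_hat = np.asarray(prediction_labels)
    metrics: Dict[str, float] = {}

    # Confusion-matrix counts with label 1 as the positive class.
    tp = float(np.sum((y_hat == 1) & (y == 1)))
    tn = float(np.sum((y_hat == 0) & (y == 0)))
    fp = float(np.sum((y_hat == 1) & (y == 0)))
    fn = float(np.sum((y_hat == 0) & (y == 1)))

    if get_accuracy:
        metrics[stage + "accuracy"] = (tp + tn) / y.size

    if get_fscore_macro:
        # Per-class F1 = 2TP / (2TP + FP + FN). For class 0 the roles swap:
        # TN acts as TP, FN as FP and FP as FN.
        f1_pos = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 0.0
        f1_neg = 2 * tn / (2 * tn + fn + fp) if tn + fn + fp > 0 else 0.0
        metrics[stage + "fscore_macro"] = (f1_pos + f1_neg) / 2

    if get_mcc:
        denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        # Convention: MCC is 0 when any marginal count is zero.
        metrics[stage + "mcc"] = float((tp * tn - fp * fn) / denom) if denom > 0 else 0.0

    return metrics


real = np.array([0, 1, 1, 0, 1, 0])
pred = np.array([0, 1, 0, 0, 1, 1])
result = classification_metrics_sketch(real, pred)
# Accuracy and macro F-score are 2/3; MCC is 1/3.
```

Disabling a flag simply omits that key from the returned dict, which matches the documented behaviour of returning only the metrics that were requested.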