Base classifier
Definition of a base PyTorch classifier.
- class ecgan.modules.classifiers.base.BaseClassifier
Bases: abc.ABC
Abstract base class for classification models.
Each concrete classifier is expected to handle classification queries on incoming data.
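The abstract methods of BaseClassifier are not listed here, so the following is only a minimal sketch of how an abc.ABC-based classifier hierarchy behaves. `SketchClassifier`, its `classify` method and `ThresholdClassifier` are hypothetical names, not part of ecgan, and NumPy stands in for PyTorch to keep the example small.

```python
from abc import ABC, abstractmethod

import numpy as np


class SketchClassifier(ABC):
    """Hypothetical stand-in for BaseClassifier; the real abstract methods may differ."""

    @abstractmethod
    def classify(self, data: np.ndarray) -> np.ndarray:
        """Hypothetical query method: map a batch of inputs to predicted labels."""


class ThresholdClassifier(SketchClassifier):
    """Toy binary classifier: predicts 1 if a sample's mean exceeds a threshold."""

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def classify(self, data: np.ndarray) -> np.ndarray:
        # data has shape (num_samples, seq_len); average over the time axis.
        return (data.mean(axis=1) > self.threshold).astype(int)
```

Instantiating `SketchClassifier` directly raises `TypeError` because `classify` is abstract; only subclasses that implement every abstract method, such as `ThresholdClassifier`, can be created.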
- static get_classification_metrics(real_label, prediction_labels, stage='metrics/', get_fscore_weighted=False, get_fscore_micro=False, get_prec_recall_fscore=True, get_accuracy=True, get_mcc=True, get_auroc=True)
Compute classification metrics from the true labels and the predicted labels.
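For reference, the standard definitions of the reported scores are given below, with K classes, per-class counts TP_c, FP_c, FN_c, support n_c (samples whose true label is c) and N samples in total. These follow common conventions (for example, scikit-learn's); that this method uses exactly these conventions is an assumption.

```latex
\begin{aligned}
P_c &= \frac{TP_c}{TP_c + FP_c}, \qquad R_c = \frac{TP_c}{TP_c + FN_c}, \qquad F_c = \frac{2 P_c R_c}{P_c + R_c} \\
F_{\text{macro}} &= \frac{1}{K} \sum_{c=1}^{K} F_c, \qquad
F_{\text{weighted}} = \sum_{c=1}^{K} \frac{n_c}{N} F_c \\
F_{\text{micro}} &= \frac{2 \sum_c TP_c}{2 \sum_c TP_c + \sum_c FP_c + \sum_c FN_c} \\
\text{MCC}_{\text{binary}} &= \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}
\end{aligned}
```

For single-label classification, every misclassified sample counts once as a false positive and once as a false negative, so the micro F-score equals the accuracy.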
- Parameters
real_label (ndarray) -- True labels (y).
prediction_labels (ndarray) -- Predicted labels (y_hat).
stage (str) -- String identifier for the logging stage.
get_fscore_weighted (bool) -- Flag indicating whether the weighted F-score should be computed.
get_fscore_micro (bool) -- Flag indicating whether the micro F-score should be computed.
get_prec_recall_fscore (bool) -- Flag indicating whether precision, recall, F-score (macro) and/or support should be computed.
get_accuracy (bool) -- Flag indicating whether the accuracy should be computed.
get_mcc (bool) -- Flag indicating whether the Matthews correlation coefficient (MCC) should be computed.
get_auroc (bool) -- Flag indicating whether the area under the ROC curve (AUROC) should be computed.
- Return type
Dict
- Returns
Dict containing every metric whose flag is set to True.
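The flag-driven behaviour described above can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration only: the function name `classification_metrics_sketch`, the dictionary keys (the `stage` prefix followed by names such as `accuracy` or `fscore_micro`) and the decision to compute AUROC only for binary labels, treating the hard predictions as scores, are all assumptions. The actual method may delegate to scikit-learn and use different keys.

```python
from typing import Dict

import numpy as np


def _safe_div(num: np.ndarray, den: np.ndarray) -> np.ndarray:
    # Element-wise num / den that returns 0 wherever the denominator is 0.
    num = np.asarray(num, dtype=float)
    den = np.asarray(den, dtype=float)
    out = np.zeros_like(num)
    np.divide(num, den, out=out, where=den != 0)
    return out


def classification_metrics_sketch(
    real_label: np.ndarray,
    prediction_labels: np.ndarray,
    stage: str = "metrics/",
    get_fscore_weighted: bool = False,
    get_fscore_micro: bool = False,
    get_prec_recall_fscore: bool = True,
    get_accuracy: bool = True,
    get_mcc: bool = True,
    get_auroc: bool = True,
) -> Dict[str, float]:
    """Hypothetical sketch: returns only the metrics whose flags are True."""
    y_true = np.asarray(real_label).ravel()
    y_pred = np.asarray(prediction_labels).ravel()
    labels = np.union1d(y_true, y_pred)  # sorted set of all observed classes
    k = len(labels)

    # Confusion matrix: rows are true classes, columns are predicted classes.
    true_idx = np.searchsorted(labels, y_true)
    pred_idx = np.searchsorted(labels, y_pred)
    conf = np.zeros((k, k), dtype=np.int64)
    np.add.at(conf, (true_idx, pred_idx), 1)

    tp = np.diag(conf).astype(float)
    fp = conf.sum(axis=0) - tp
    fn = conf.sum(axis=1) - tp
    support = conf.sum(axis=1)
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)

    metrics: Dict[str, float] = {}
    if get_prec_recall_fscore:
        # Macro averaging: unweighted mean over classes.
        metrics[stage + "precision"] = float(precision.mean())
        metrics[stage + "recall"] = float(recall.mean())
        metrics[stage + "fscore"] = float(f1.mean())
    if get_fscore_weighted:
        # Per-class F-scores weighted by class support.
        metrics[stage + "fscore_weighted"] = float(np.sum(support * f1) / support.sum())
    if get_fscore_micro:
        # Micro averaging: pool TP/FP/FN over all classes first.
        metrics[stage + "fscore_micro"] = float(2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum()))
    if get_accuracy:
        metrics[stage + "accuracy"] = float(tp.sum() / conf.sum())
    if get_mcc:
        # Multiclass MCC computed from the confusion matrix; equals binary MCC for k == 2.
        s = conf.sum()
        t_k = conf.sum(axis=1)
        p_k = conf.sum(axis=0)
        num = tp.sum() * s - t_k @ p_k
        den = np.sqrt(float(s**2 - p_k @ p_k)) * np.sqrt(float(s**2 - t_k @ t_k))
        metrics[stage + "mcc"] = float(num / den) if den > 0 else 0.0
    if get_auroc and k == 2 and len(np.unique(y_true)) == 2:
        # Binary AUROC via the Mann-Whitney statistic: the fraction of
        # (positive, negative) pairs ranked correctly, with ties counting 0.5.
        pos = pred_idx[true_idx == 1]
        neg = pred_idx[true_idx == 0]
        diff = pos[:, None] - neg[None, :]
        metrics[stage + "auroc"] = float(np.mean((diff > 0) + 0.5 * (diff == 0)))
    return metrics
```

For `real_label = [0, 0, 1, 1]` and `prediction_labels = [0, 1, 1, 1]`, this sketch reports an accuracy of 0.75, a macro F-score of 11/15 and an MCC of about 0.577. Because only hard labels are passed in, the binary AUROC here reduces to balanced accuracy (the mean of the true positive and true negative rates). A score-based AUROC would require predicted probabilities, which this signature does not take.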