"""
Calculates metrics of the models and houses list of the available metrics to use.
"""
import torch
import torchmetrics as tm
[docs]
class MetricsCalculator:
"""
Calculates metrics of the models and houses list of the available metrics to use.
"""
[docs]
def __init__(self, base_model_instance):
"""
Parameters
----------
base_model_instance : fusilli.fusionmodels.base_model.BaseModel
Instance of the base model. Has information on the prediction task and multiclass dimensions if applicable.
"""
self.model = base_model_instance
self.prediction_task = base_model_instance.model.prediction_task
[docs]
def auroc(self, preds, labels, logits):
"""
Area under the receiver operating characteristic curve.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
AUROC value.
"""
if self.prediction_task == "binary":
auroc_equation = tm.AUROC(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
auroc_equation = tm.AUROC(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for AUROC.")
return auroc_equation(logits, labels)
[docs]
def accuracy(self, preds, labels, logits):
"""
Calculates accuracy.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
Accuracy value.
"""
if self.prediction_task == "binary":
# do binary accuracy
accuracy_equation = tm.Accuracy(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
# do multiclass accuracy
accuracy_equation = tm.Accuracy(
num_classes=self.model.multiclass_dimensions, task="multiclass", top_k=1
).to(preds.device)
else:
raise ValueError("Invalid prediction task for accuracy.")
return accuracy_equation(preds, labels)
[docs]
def r2(self, preds, labels, logits):
"""
Calculates R2 score.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
R2 score value.
"""
if self.prediction_task != "regression":
raise ValueError("Invalid prediction task for R2.")
return tm.R2Score().to(preds.device)(preds, labels)
[docs]
def mse(self, preds, labels, logits):
"""
Calculates mean squared error.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
MSE value.
"""
if self.prediction_task != "regression":
raise ValueError("Invalid prediction task for mse.")
return tm.MeanSquaredError().to(preds.device)(preds, labels)
[docs]
def mae(self, preds, labels, logits):
"""
Calculates mean absolute error.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
MAE value.
"""
if self.prediction_task != "regression":
raise ValueError("Invalid prediction task for mae.")
return tm.MeanAbsoluteError().to(preds.device)(preds, labels)
[docs]
def recall(self, preds, labels, logits):
"""
Calculates recall. This is equivalent to sensitivity.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
Recall value.
"""
if self.prediction_task == "binary":
recall_equation = tm.Recall(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
recall_equation = tm.Recall(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for recall.")
return recall_equation(preds, labels)
[docs]
def specificity(self, preds, labels, logits):
"""
Calculates specificity.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
Specificity value.
"""
if self.prediction_task == "binary":
specificity_equation = tm.Specificity(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
specificity_equation = tm.Specificity(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for specificity.")
return specificity_equation(preds, labels)
[docs]
def precision(self, preds, labels, logits):
"""
Calculates precision.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
Precision value.
"""
if self.prediction_task == "binary":
precision_equation = tm.Precision(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
precision_equation = tm.Precision(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for precision.")
return precision_equation(preds, labels)
[docs]
def f1(self, preds, labels, logits):
"""
Calculates F1 score. This is equivalent to the Dice coefficient.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
F1 score value.
"""
if self.prediction_task == "binary":
f1_equation = tm.F1Score(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
f1_equation = tm.F1Score(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for F1.")
return f1_equation(preds, labels)
[docs]
def auprc(self, preds, labels, logits):
"""
Calculates area under the precision-recall curve.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
AUPRC value.
"""
if self.prediction_task == "binary":
auprc_equation = tm.AveragePrecision(task="binary").to(preds.device)
elif self.prediction_task == "multiclass":
auprc_equation = tm.AveragePrecision(
num_classes=self.model.multiclass_dimensions, task="multiclass"
).to(preds.device)
else:
raise ValueError("Invalid prediction task for AUPRC.")
return auprc_equation(logits, labels)
[docs]
def balanced_accuracy(self, preds, labels, logits):
"""
Calculates balanced accuracy.
Parameters
----------
preds : torch.Tensor
Predicted values from the model.
labels : torch.Tensor
True labels.
logits : torch.Tensor
Probability values from the model.
Returns
-------
float
Balanced accuracy value.
"""
if self.prediction_task == "binary":
balanced_accuracy_equation = tm.Accuracy(
task="multiclass", num_classes=2, average="macro"
).to(preds.device)
elif self.prediction_task == "multiclass":
balanced_accuracy_equation = tm.Accuracy(
task="multiclass",
num_classes=self.model.multiclass_dimensions,
average="macro",
).to(preds.device)
else:
raise ValueError("Invalid prediction task for balanced accuracy.")
return balanced_accuracy_equation(preds, labels)