"""
This module contains classes and functions for evaluating trained models (i.e. plotting results from training).
The setup for this module has been inspired by the scikit-learn API for plotting results, which involves each plot
being a class with a ``from_final_val_data`` method that takes in a trained model and returns a plot with the validation data,
and a ``from_new_data`` method that takes in a trained model and new data and returns a plot.
"""
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import gridspec
from sklearn.metrics import confusion_matrix
from torch.utils.data import ConcatDataset, DataLoader
import fusilli.data as data
from fusilli.fusionmodels.base_model import BaseModel
from fusilli.utils.training_utils import (
get_checkpoint_filename_for_trained_fusion_model,
)
from fusilli.utils import model_modifier
[docs]
class ParentPlotter:
    """Parent class for all plot classes.

    It gathers the data that the plot classes need: the final validation data stored
    on trained kfold and train/test models, and predictions obtained by putting new
    data through re-loaded copies of those models (for both protocols).
    """

    def __init__(self):
        """Nothing to initialise: all functionality lives in class/static methods."""

    @staticmethod
    def _subspace_checkpoint_paths(model, checkpoint_dir, checkpoint_file_suffix, fold=None):
        """Return the subspace-model checkpoint paths for ``model``, or ``None``.

        ``None`` is returned when ``model.model.subspace_method`` is ``None``. When
        ``fold`` is given, ``_fold_<fold>`` is inserted before the user suffix.
        """
        fusion_model = model.model
        if fusion_model.subspace_method is None:
            return None

        fold_part = "" if fold is None else f"_fold_{fold}"
        fusion_name = fusion_model.__class__.__name__
        return [
            f"{checkpoint_dir}/subspace_{fusion_name}_{subspace_model.__name__}"
            f"{fold_part}{checkpoint_file_suffix}.ckpt"
            for subspace_model in fusion_model.subspace_method.subspace_models
        ]

    @staticmethod
    def _load_trained_copy(model, dm, checkpoint_path, layer_mods):
        """Rebuild the fusion model for the new data and load its trained weights.

        A fresh ``BaseModel`` is created from the class of ``model.model`` using the
        dimensions reported by the data module ``dm``, optionally modified with
        ``layer_mods``, filled with the ``state_dict`` stored at ``checkpoint_path``
        and switched to eval mode.
        """
        new_model = BaseModel(
            model=model.model.__class__(
                prediction_task=model.model.prediction_task,
                data_dims=dm.data_dims,  # data_dims is a list of tuples
                multiclass_dimensions=dm.multiclass_dimensions,
            ),
            metrics_list=model.metrics_list,
        )

        if layer_mods is not None:
            new_model.model = model_modifier.modify_model_architecture(
                new_model.model,
                layer_mods,
            )

        # map_location keeps checkpoints written on a GPU loadable on CPU-only
        # machines; load_state_dict copies the values into the new model's own
        # parameters regardless of where the checkpoint tensors live.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        new_model.load_state_dict(checkpoint["state_dict"])
        new_model.eval()
        return new_model

    @classmethod
    def get_kfold_data_from_model(cls, model_list):
        """
        Get the final validation data from a kfold model, meaning the data that was
        used to evaluate the model when the model training was complete.

        Parameters
        ----------
        model_list: list
            List of trained pytorch_lightning models, one per fold (first element is
            the k=1 model, second element the k=2 model, etc.).

        Returns
        -------
        train_reals: list
            Real values of the training set for each fold (``fold.train_reals``).
        train_preds: list
            Predicted values of the training set for each fold (``fold.train_preds``).
        val_reals: list
            Real values of the validation set for each fold (``fold.val_reals``).
        val_preds: list
            Predicted values of the validation set for each fold (``fold.val_preds``).
        metrics_per_fold: dict
            Lower-cased metric name -> list with each fold's entry of
            ``fold.final_val_metrics``.
        overall_kfold_metrics: dict
            Lower-cased metric name -> metric value (Python number) computed on the
            validation data concatenated over all folds.
        """
        reference_model = model_list[0]
        metric_names = list(reference_model.metrics.keys())
        metrics_per_fold = {name.lower(): [] for name in metric_names}

        train_reals, train_preds = [], []
        val_reals, val_preds, val_logits = [], [], []

        for fold in model_list:
            # data points stored on the trained model
            train_reals.append(fold.train_reals.cpu())
            train_preds.append(fold.train_preds.cpu())
            val_reals.append(fold.val_reals.cpu())
            val_preds.append(fold.val_preds.cpu())
            val_logits.append(fold.val_logits.cpu())

            # final_val_metrics is read positionally, in the order of the metrics dict
            for i, metric in enumerate(fold.final_val_metrics):
                metrics_per_fold[metric_names[i].lower()].append(metric)

        # concatenate the validation data points for the overall kfold performance
        all_val_reals = torch.cat(val_reals, dim=-1)
        all_val_preds = torch.cat(val_preds, dim=-1)
        all_val_logits = torch.cat(val_logits, dim=0)

        overall_kfold_metrics = {}
        for metric_name, metric_func in reference_model.metrics.items():
            val_metric = metric_func(
                preds=reference_model.safe_squeeze(all_val_preds),
                labels=reference_model.safe_squeeze(all_val_reals),
                logits=reference_model.safe_squeeze(all_val_logits),
            )
            overall_kfold_metrics[metric_name.lower()] = val_metric.cpu().detach().item()

        return (
            train_reals,
            train_preds,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        )

    @classmethod
    def get_tt_data_from_model(cls, model_list):
        """
        Get the final validation data from a train/test model, meaning the data that
        was used to evaluate the model when the model training was complete.

        Parameters
        ----------
        model_list: list
            A list of length 1 containing the trained pytorch_lightning model.

        Returns
        -------
        train_reals: torch.Tensor
            Real values of the training set (``model.train_reals``).
        train_preds: torch.Tensor
            Predicted values of the training set (``model.train_preds``).
        val_reals: torch.Tensor
            Real values of the validation set (``model.val_reals``).
        val_preds: torch.Tensor
            Predicted values of the validation set (``model.val_preds``).
        metric_values: dict
            Metric name -> the matching entry of ``model.final_val_metrics``.
            The keys keep the casing used in ``model.metrics`` (they are not
            lower-cased here).
        """
        model = model_list[0]  # the single model of a train/test run
        model.eval()  # not training the model

        train_reals = model.train_reals.cpu()
        train_preds = model.train_preds.cpu()
        val_reals = model.val_reals.cpu()
        val_preds = model.val_preds.cpu()

        # final_val_metrics is read positionally, in the order of the metrics dict
        metric_values = {}
        for i, metric in enumerate(model.metrics):
            metric_values[metric] = model.final_val_metrics[i]

        return train_reals, train_preds, val_reals, val_preds, metric_values

    @classmethod
    def get_new_kfold_data(
        cls, model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None
    ):
        """
        Get new data by running through trained model for a kfold model.

        Parameters
        ----------
        model_list: list
            List of trained pytorch_lightning models, one per fold (first element is
            the k=1 model, second element the k=2 model, etc.).
        output_paths: dict
            Dictionary of the output paths. ``output_paths["checkpoints"]`` is used to
            locate the checkpoint files.
        test_data_paths: dict
            Dictionary of the paths to the new data. The keys are the names of the
            data types (e.g. "tabular1", "image").
        checkpoint_file_suffix: str, optional
            Suffix that is on the trained model checkpoint files, e.g. "_firsttry".
            Default is None, which is treated as an empty suffix.
        layer_mods: dict, optional
            Dictionary of the layer modifications to make to the model.

        Returns
        -------
        train_reals: list
            Real values of the training set for each fold, taken from the trained models.
        train_preds: list
            Predicted values of the training set for each fold, taken from the trained models.
        val_reals: list
            Real values of the new data set for each fold.
        val_preds: list
            Predicted values of the new data set for each fold, obtained by running
            the new data through the re-loaded model of that fold.
        metrics_per_fold: dict
            Lower-cased metric name -> list of that metric's value on the new data
            for each fold (as returned by the metric function).
        overall_kfold_metrics: dict
            Lower-cased metric name -> metric value (Python number) computed on the
            new-data predictions concatenated over all folds.

        Raises
        ------
        ValueError
            If the model has a graph maker, it's not supported yet for creating graphs from new data.
        """
        if checkpoint_file_suffix is None:
            checkpoint_file_suffix = ""

        train_reals, train_preds = [], []
        val_reals, val_preds, val_logits = [], [], []

        metric_names = list(model_list[0].metrics.keys())
        metrics_per_fold = {name.lower(): [] for name in metric_names}

        output_paths_copy = output_paths.copy()
        num_folds = len(model_list)

        for k, model in enumerate(model_list):
            model.eval()

            if hasattr(model.model, "graph_maker"):
                raise ValueError(
                    "Model has a graph maker. This is not supported yet for creating graphs from new data."
                )

            subspace_ckpts = cls._subspace_checkpoint_paths(
                model, output_paths["checkpoints"], checkpoint_file_suffix, fold=k
            )

            # NOTE(review): the fusion model *instance* is passed here, whereas
            # get_new_tt_data passes the class - kept as in the original; confirm
            # which one prepare_fusion_data expects.
            dm = data.prepare_fusion_data(
                prediction_task=model.model.prediction_task,
                fusion_model=model.model,
                data_paths=test_data_paths,
                output_paths=output_paths_copy,
                kfold=True,
                num_folds=num_folds,
                checkpoint_path=subspace_ckpts,
                layer_mods=layer_mods,
            )

            # just taking the first fold because we don't need to split the new data
            # into folds; we only wanted it converted with this fold's trained
            # subspace model. Train and test parts are rejoined straight away.
            dm.train_dataset = dm.folds[0][0]
            dm.test_dataset = dm.folds[0][1]
            dataset = ConcatDataset([dm.train_dataset, dm.test_dataset])
            dataloader = DataLoader(dataset, batch_size=len(dataset))

            trained_fusion_model_checkpoint = get_checkpoint_filename_for_trained_fusion_model(
                checkpoint_dir=output_paths["checkpoints"],
                model=model,
                checkpoint_file_suffix=checkpoint_file_suffix,
                fold=k,
            )
            new_model = cls._load_trained_copy(
                model, dm, trained_fusion_model_checkpoint, layer_mods
            )

            fold_val_preds = []
            fold_val_logits = []
            fold_val_reals = []
            # Pure inference: no autograd graph is needed for the (dataset-sized) batch.
            with torch.no_grad():
                for batch in dataloader:
                    x, y = new_model.get_data_from_batch(batch)
                    _, end_output, logits = new_model.get_model_outputs_and_loss(x, y)
                    fold_val_preds.append(end_output.cpu().detach())
                    fold_val_logits.append(logits.cpu().detach())
                    fold_val_reals.append(y.cpu().detach())

            fold_val_reals = torch.cat(fold_val_reals, dim=-1)
            fold_val_preds = torch.cat(fold_val_preds, dim=-1)
            fold_val_logits = torch.cat(fold_val_logits, dim=0)

            val_reals.append(fold_val_reals)
            val_preds.append(fold_val_preds)
            val_logits.append(fold_val_logits)

            # training data points come from the old trained BaseModel
            train_reals.append(model.train_reals.cpu().detach())
            train_preds.append(model.train_preds.cpu().detach())

            for metric_name, metric_func in new_model.metrics.items():
                val_step_metric = metric_func(
                    preds=model_list[0].safe_squeeze(fold_val_preds),
                    labels=model_list[0].safe_squeeze(fold_val_reals),
                    logits=model_list[0].safe_squeeze(fold_val_logits),
                )
                metrics_per_fold[metric_name.lower()].append(val_step_metric)

        # concatenate the validation data points for the overall kfold performance
        all_val_reals = torch.cat(val_reals, dim=-1)
        all_val_preds = torch.cat(val_preds, dim=-1)
        all_val_logits = torch.cat(val_logits, dim=0)

        # the metric functions of the last re-loaded fold model are used here
        overall_kfold_metrics = {}
        for metric_name, metric_func in new_model.metrics.items():
            val_metric = metric_func(
                preds=model_list[0].safe_squeeze(all_val_preds),
                labels=model_list[0].safe_squeeze(all_val_reals),
                logits=model_list[0].safe_squeeze(all_val_logits),
            )
            overall_kfold_metrics[metric_name.lower()] = val_metric.cpu().detach().item()

        return (
            train_reals,
            train_preds,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        )

    @classmethod
    def get_new_tt_data(
        cls, model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None
    ):
        """
        Get new data by running through trained model for a train/test model.

        Parameters
        ----------
        model_list: list
            A list of length 1 containing the trained pytorch_lightning model.
        output_paths: dict
            Dictionary of the output paths. ``output_paths["checkpoints"]`` is used to
            locate the checkpoint files.
        test_data_paths: dict
            Dictionary of the paths to the new data. The keys are the names of the
            data types (e.g. "tabular1", "image").
        checkpoint_file_suffix: str, optional
            Suffix that is on the trained model checkpoint files, e.g. "_firsttry".
            Default is None, which is treated as an empty suffix.
        layer_mods: dict, optional
            Dictionary of the layer modifications to make to the model.

        Returns
        -------
        train_reals: torch.Tensor
            Real values of the training set, taken from the trained model.
        train_preds: torch.Tensor
            Predicted values of the training set, taken from the trained model.
        val_reals: torch.Tensor
            Real values of the new data set.
        val_preds: torch.Tensor
            Predicted values of the new data set.
        metric_values: dict
            Lower-cased metric name -> metric value on the new data (as returned by
            the metric function).

        Raises
        ------
        ValueError
            If the model has a graph maker, it's not supported yet for creating graphs from new data.
        """
        if checkpoint_file_suffix is None:
            checkpoint_file_suffix = ""

        model = model_list[0]
        model.eval()

        if hasattr(model.model, "graph_maker"):
            raise ValueError(
                "Model has a graph maker. This is not supported yet for creating graphs from new data."
            )

        subspace_ckpts = cls._subspace_checkpoint_paths(
            model, output_paths["checkpoints"], checkpoint_file_suffix
        )

        # get data module (the subspace checkpoints, if any, are handed over here)
        dm = data.prepare_fusion_data(
            prediction_task=model.model.prediction_task,
            fusion_model=model.model.__class__,
            data_paths=test_data_paths,
            output_paths=output_paths,
            checkpoint_path=subspace_ckpts,
            layer_mods=layer_mods,
        )

        # concatenating the train and test datasets because we want the predictions
        # for all of the new data
        dataset = ConcatDataset([dm.train_dataset, dm.test_dataset])
        dataloader = DataLoader(dataset, batch_size=len(dataset))

        trained_fusion_model_checkpoint = get_checkpoint_filename_for_trained_fusion_model(
            output_paths["checkpoints"], model, checkpoint_file_suffix, fold=None
        )
        new_model = cls._load_trained_copy(
            model, dm, trained_fusion_model_checkpoint, layer_mods
        )

        end_outputs_list = []
        logits_list = []
        reals_list = []
        # Pure inference: no autograd graph is needed for the (dataset-sized) batch.
        with torch.no_grad():
            for batch in dataloader:
                x, y = new_model.get_data_from_batch(batch)
                _, end_output, logits = new_model.get_model_outputs_and_loss(x, y)
                end_outputs_list.append(new_model.safe_squeeze(end_output).cpu().detach())
                logits_list.append(new_model.safe_squeeze(logits).cpu().detach())
                reals_list.append(new_model.safe_squeeze(y).cpu().detach())

        train_reals = model.train_reals.cpu()
        train_preds = model.train_preds.cpu()
        val_preds = torch.cat(end_outputs_list, dim=-1)
        val_reals = torch.cat(reals_list, dim=-1)
        val_logits = torch.cat(logits_list, dim=0)

        metric_values = {}
        for metric_name, metric_func in new_model.metrics.items():
            val_step_metric = metric_func(
                preds=new_model.safe_squeeze(val_preds),
                labels=new_model.safe_squeeze(val_reals),
                logits=new_model.safe_squeeze(val_logits),
            )
            metric_values[metric_name.lower()] = val_step_metric

        return train_reals, train_preds, val_reals, val_preds, metric_values
[docs]
class RealsVsPreds(ParentPlotter):
    """
    Plots the real values vs the predicted values for a model.

    The plotted validation data is either new data (``from_new_data``) or the
    original validation data (``from_final_val_data``). For train/test models the
    training points are drawn as "#f082ef" circles and the validation points as
    "#00b64e" triangles; for kfold models only each fold's validation points are drawn.
    """

    def __init__(self):
        """Defer to the parent initialiser; no extra state is stored."""
        super().__init__()

    @staticmethod
    def _require_model_list(model_list):
        """Raise ``ValueError`` unless ``model_list`` is a non-empty list."""
        if not isinstance(model_list, list):
            raise ValueError(
                "Argument 'model_list' is not a list. "
                "Please check the model and the function input. "
                "If you are using a train/test model, the single model must be in a list of length 1."
            )
        if len(model_list) == 0:
            raise ValueError("Argument 'model_list' is an empty list. ")

    @classmethod
    def from_new_data(
        cls, model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None
    ):
        """
        Reals vs preds plot using new data (i.e. data that was not used to train or
        validate the model).

        Parameters
        ----------
        model_list: list
            List of trained pytorch_lightning models. A list of length > 1 is treated
            as a kfold run (one model per fold); a list of length 1 as a train/test run.
        output_paths : dict
            Dictionary of the output paths. Used for knowing where the checkpoint
            files are stored.
        test_data_paths: dict
            Dictionary of the paths to the new data. The keys are the names of the
            data types (e.g. "tabular1", "image").
        checkpoint_file_suffix: str, optional
            Suffix that is on the trained model checkpoint files, e.g. "_firsttry".
            Default is None.
        layer_mods: dict, optional
            Dictionary of the layer modifications to make to the model.

        Returns
        -------
        figure: matplotlib.pyplot.figure
            The figure of the plot, titled "Evaluation: External Test Data".

        Raises
        ------
        ValueError
            If ``model_list`` is not a list, or is an empty list.
        """
        cls._require_model_list(model_list)

        if len(model_list) > 1:  # kfold: one model per fold
            (
                _train_reals,
                _train_preds,
                val_reals,
                val_preds,
                metrics_per_fold,
                overall_kfold_metrics,
            ) = cls.get_new_kfold_data(
                model_list, output_paths, test_data_paths, checkpoint_file_suffix, layer_mods
            )
            figure = cls.reals_vs_preds_kfold(
                model_list,
                val_reals,
                val_preds,
                metrics_per_fold,
                overall_kfold_metrics,
            )
        else:  # train/test: a single model
            (
                train_reals,
                train_preds,
                val_reals,
                val_preds,
                metric_values,
            ) = cls.get_new_tt_data(
                model_list, output_paths, test_data_paths, checkpoint_file_suffix, layer_mods
            )
            figure = cls.reals_vs_preds_tt(
                model_list,
                train_reals,
                train_preds,
                val_reals,
                val_preds,
                metric_values,
            )

        # replaces any figure-level title set by the plotting helper
        figure.suptitle("Evaluation: External Test Data")
        return figure

    @classmethod
    def from_final_val_data(cls, model_list):
        """
        Reals vs preds plot using the final validation data (i.e. the data that was
        used to evaluate the model when the model training was complete).

        Parameters
        ----------
        model_list: list
            List of trained pytorch_lightning models. A list of length > 1 is treated
            as a kfold run (one model per fold); a list of length 1 as a train/test run.

        Returns
        -------
        figure: matplotlib.pyplot.figure
            The figure of the plot, titled "Evaluation: Validation Data".

        Raises
        ------
        ValueError
            If ``model_list`` is not a list, or is an empty list.
        """
        cls._require_model_list(model_list)

        if len(model_list) > 1:  # kfold: one model per fold
            (
                _train_reals,
                _train_preds,
                val_reals,
                val_preds,
                metrics_per_fold,
                overall_kfold_metrics,
            ) = cls.get_kfold_data_from_model(model_list)
            figure = cls.reals_vs_preds_kfold(
                model_list,
                val_reals,
                val_preds,
                metrics_per_fold,
                overall_kfold_metrics,
            )
        else:  # train/test: a single model
            (
                train_reals,
                train_preds,
                val_reals,
                val_preds,
                metric_values,
            ) = cls.get_tt_data_from_model(model_list)
            figure = cls.reals_vs_preds_tt(
                model_list,
                train_reals,
                train_preds,
                val_reals,
                val_preds,
                metric_values,
            )

        # replaces any figure-level title set by the plotting helper
        figure.suptitle("Evaluation: Validation Data")
        return figure

    @classmethod
    def reals_vs_preds_kfold(
        cls,
        model_list,
        val_reals,
        val_preds,
        metrics_per_fold,
        overall_kfold_metrics,
    ):
        """
        Reals vs preds plot for a kfold model. Call this after the k-fold data has
        been obtained from the model (either old data or new data).

        The left half of the figure shows all folds' points together; the right half
        holds one small panel per fold on a 3-column grid with shared axes.

        Parameters
        ----------
        model_list: list
            List of trained pytorch_lightning models, one per fold.
        val_reals: list
            List of torch.Tensors of the real values for each fold.
        val_preds: list
            List of torch.Tensors of the predicted values for each fold.
        metrics_per_fold: dict
            Metric name -> list of the metric values for each fold. The first key is
            the metric shown in the per-fold titles.
        overall_kfold_metrics: dict
            Metric name -> metric value over the concatenated folds. Must contain the
            first key of ``metrics_per_fold``.

        Returns
        -------
        fig: matplotlib.pyplot.figure
            The figure of the plot.
        """
        first_fold_model = model_list[0]
        # Use the key exactly as stored, so the lookup cannot miss on casing.
        first_metric = list(metrics_per_fold.keys())[0]

        n_folds = len(model_list)
        cols = 3
        rows = int(math.ceil(n_folds / cols))

        fig = plt.figure(constrained_layout=True, figsize=(15, 6))
        subplots = fig.subfigures(1, 2)
        ax0 = subplots[0].subplots(1, 1)

        gs = gridspec.GridSpec(
            rows,
            cols,
            hspace=0.5,
            wspace=0.7,
        )

        shared_ax = None
        for n in range(n_folds):
            if shared_ax is None:
                ax1 = subplots[1].add_subplot(gs[n])
                shared_ax = ax1  # every later panel shares this panel's axes
            else:
                ax1 = subplots[1].add_subplot(gs[n], sharey=shared_ax, sharex=shared_ax)

            # real vs. predicted values of the current fold
            ax1.scatter(val_reals[n], val_preds[n], c="#f082ef", marker="o")

            # dashed diagonal drawn in axes coordinates (corner to corner)
            ax1.plot(
                [0, 1],
                [0, 1],
                color="k",
                linestyle="--",
                alpha=0.75,
                zorder=0,
                transform=ax1.transAxes,
            )

            ax1.set_title(
                f"Fold {n + 1}: {first_metric}={float(metrics_per_fold[first_metric][n]):.3f}"
            )
            ax1.set_xlabel("Real Values")
            ax1.set_ylabel("Predictions")

        all_val_reals = torch.cat(val_reals, dim=-1)
        all_val_preds = torch.cat(val_preds, dim=-1)

        # all folds together
        ax0.scatter(all_val_reals, all_val_preds, c="#f082ef", marker="o")
        ax0.plot(
            [0, 1],
            [0, 1],
            color="k",
            linestyle="--",
            alpha=0.75,
            zorder=0,
            transform=ax0.transAxes,
        )
        ax0.set_title(
            (
                f"{first_fold_model.model.method_name}: {first_metric}"
                f"={float(overall_kfold_metrics[first_metric]):.3f}"
            )
        )
        ax0.set_xlabel("Real Values")
        ax0.set_ylabel("Predictions")

        # overall title for the entire figure
        fig.suptitle(
            f"{first_fold_model.model.__class__.__name__}: reals vs. predicteds"
        )

        return fig

    @classmethod
    def reals_vs_preds_tt(
        cls, model_list, train_reals, train_preds, val_reals, val_preds, metric_values
    ):
        """
        Reals vs preds plot for a train/test model. Call this after the train/test
        data has been obtained from the model (either old data or new data).

        Parameters
        ----------
        model_list: list
            A list of length 1 containing the trained pytorch_lightning model.
        train_reals: torch.Tensor
            Real values for the training set.
        train_preds: torch.Tensor
            Predicted values for the training set.
        val_reals: torch.Tensor
            Real values for the validation (or new) data set.
        val_preds: torch.Tensor
            Predicted values for the validation (or new) data set.
        metric_values: dict
            Metric name -> metric value. The first entry is shown in the title.

        Returns
        -------
        fig: matplotlib.pyplot.figure
            The figure of the plot.
        """
        model = model_list[0]

        fig, ax = plt.subplots()

        ax.scatter(
            train_reals,
            train_preds,
            c="#f082ef",
            marker="o",
            label="Train",
        )
        ax.scatter(val_reals, val_preds, c="#00b64e", marker="^", label="Validation")

        # Limits are read from, and the line drawn on, the axes created above rather
        # than whatever pyplot currently considers the "current" axes.
        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()

        # x=y reference line spanning the combined range of both axes
        line_x = np.linspace(min(x_min, y_min), max(x_max, y_max), 100)
        ax.plot(line_x, line_x, linestyle="dashed", color="black", label="x=y Line")

        metric1_name = list(metric_values.keys())[0]
        ax.set_title(
            (
                f"{model.model.method_name} - Validation {metric1_name}:"
                f" {float(metric_values[metric1_name]):.3f}"
            )
        )

        ax.set_xlabel("Real Values")
        ax.set_ylabel("Predictions")
        ax.legend()

        return fig
[docs]
class ConfusionMatrix(ParentPlotter):
"""
Plots the confusion matrix for a model. This should be used for classification models only (binary or multiclass).
The data used to create the confusion matrix is either new data if using from_new_data or the original validation data
if using from_final_val_data.
"""
[docs]
def __init__(self):
    """Defer to the parent initialiser; no extra state is stored."""
    super().__init__()
[docs]
@classmethod
def from_new_data(cls, model_list, output_paths, test_data_paths,
                  checkpoint_file_suffix=None, layer_mods=None):
    """
    Confusion matrix using new data (i.e. data that was not used to train or
    validate the model).

    Parameters
    ----------
    model_list: list
        List of trained pytorch_lightning models. A list of length > 1 is treated
        as a kfold run (one model per fold); a list of length 1 as a train/test run.
    output_paths: dict
        Dictionary of the output paths. Used for knowing where the checkpoint
        files are stored.
    test_data_paths: dict
        Dictionary of the paths to the new data. The keys are the names of the
        data types (e.g. "tabular1", "image").
    checkpoint_file_suffix: str, optional
        Suffix that is on the trained model checkpoint files, e.g. "_firsttry".
        Default is None.
    layer_mods: dict, optional
        Dictionary of the layer modifications to make to the model.

    Returns
    -------
    figure: matplotlib.pyplot.figure
        The figure of the plot.

    Raises
    ------
    ValueError
        If ``model_list`` is not a list, or is an empty list.
    """
    if not isinstance(model_list, list):
        raise ValueError(
            "Argument 'model_list' is not a list. "
            "Please check the model and the function input. "
            "If you are using a train/test model, the single model must be in a list of length 1."
        )
    if len(model_list) == 0:
        raise ValueError("Argument 'model_list' is an empty list. ")

    if len(model_list) > 1:  # kfold model: one entry per fold
        (
            _train_reals,
            _train_preds,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        ) = cls.get_new_kfold_data(
            model_list, output_paths, test_data_paths, checkpoint_file_suffix, layer_mods
        )
        return cls.confusion_matrix_kfold(
            model_list,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        )

    # train/test model: a single entry
    (
        _train_reals,
        _train_preds,
        val_reals,
        val_preds,
        metric_values,
    ) = cls.get_new_tt_data(
        model_list, output_paths, test_data_paths, checkpoint_file_suffix, layer_mods
    )
    return cls.confusion_matrix_tt(model_list, val_reals, val_preds, metric_values)
[docs]
@classmethod
def from_final_val_data(cls, model_list):
    """
    Confusion matrix using the final validation data (i.e. the data that was used
    to evaluate the model when the model training was complete).

    Parameters
    ----------
    model_list: list
        List of trained pytorch_lightning models. A list of length > 1 is treated
        as a kfold run (one model per fold); a list of length 1 as a train/test run.

    Returns
    -------
    figure: matplotlib.pyplot.figure
        The figure of the plot.

    Raises
    ------
    ValueError
        If ``model_list`` is not a list, or is an empty list.
    """
    if not isinstance(model_list, list):
        raise ValueError(
            "Argument 'model_list' is not a list. "
            "Please check the model and the function input. "
            "If you are using a train/test model, the single model must be in a list of length 1."
        )
    if len(model_list) == 0:
        raise ValueError("Argument 'model_list' is an empty list. ")

    if len(model_list) > 1:  # kfold model: one entry per fold
        (
            _train_reals,
            _train_preds,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        ) = cls.get_kfold_data_from_model(model_list)
        return cls.confusion_matrix_kfold(
            model_list,
            val_reals,
            val_preds,
            metrics_per_fold,
            overall_kfold_metrics,
        )

    # train/test model: a single entry
    (
        _train_reals,
        _train_preds,
        val_reals,
        val_preds,
        metric_values,
    ) = cls.get_tt_data_from_model(model_list)
    return cls.confusion_matrix_tt(model_list, val_reals, val_preds, metric_values)
[docs]
@classmethod
def confusion_matrix_tt(cls, model_list, val_reals, val_preds, metric_values):
    """
    Confusion matrix for a train/test model. Call this after the train/test data
    has been obtained from the model (either old data or new data).

    Parameters
    ----------
    model_list: list
        A list of length 1 containing the trained pytorch_lightning model.
    val_reals: torch.Tensor
        Real values for the validation (or new) data set.
    val_preds: torch.Tensor
        Predicted values for the validation (or new) data set.
    metric_values: dict
        Metric name -> metric value. The first entry is shown in the title.

    Returns
    -------
    fig: matplotlib.pyplot.figure
        The figure of the plot.
    """
    conf_matrix = confusion_matrix(y_true=val_reals, y_pred=val_preds)

    fig, ax = plt.subplots(figsize=(7.5, 7.5))

    # heatmap of the matrix, with each cell's count written on top of it
    ax.matshow(conf_matrix, cmap=plt.cm.RdPu, alpha=0.3)
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            ax.text(
                x=j,
                y=i,
                s=conf_matrix[i, j],
                va="center",
                ha="center",
                size="xx-large",
            )

    # Labels, title and layout are applied to the axes/figure created above rather
    # than to whatever pyplot currently considers "current".
    ax.set_xlabel("Predictions", fontsize=18)
    ax.set_ylabel("Actuals", fontsize=18)

    metric1_name = list(metric_values.keys())[0]
    ax.set_title(
        f"{model_list[0].model.method_name} - Validation {metric1_name}: {float(metric_values[metric1_name]):.3f}"
    )

    fig.tight_layout()

    return fig
[docs]
@classmethod
def confusion_matrix_kfold(
    cls,
    model_list,
    val_reals,
    val_preds,
    metrics_per_fold,
    overall_kfold_metrics,
):
    """
    Confusion matrix for a kfold model. This function should be called within the ConfusionMatrix class
    after the k-fold data has been obtained from the model (either old data or new data).

    The left half of the figure shows the confusion matrix of all folds concatenated; the right half
    shows one small confusion matrix per fold, laid out three per row.

    Parameters
    ----------
    model_list: list
        List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first
        element is the k=1 model and the second element is the k=2 model, etc.
    val_reals: list
        List of torch.Tensors of the real values for each fold.
    val_preds: list
        List of torch.Tensors of the predicted values for each fold.
    metrics_per_fold: dict
        Dictionary of the metrics for each fold.
        The keys are the names of the metrics and the values are lists of the metric values for each fold.
        Only the first metric is shown in the titles.
    overall_kfold_metrics: dict
        Dictionary of the overall kfold metrics.
        The keys are the names of the metrics and the values are the metric values for the overall kfold,
        meaning the metric values for the concatenated data over all the folds.

    Returns
    -------
    fig: matplotlib.pyplot.figure
        The figure of the plot.
    """

    def annotate_cells(axis, matrix, text_size):
        # Write each cell's count on top of the heatmap (x is the column, y is the row).
        for (row, col), count in np.ndenumerate(matrix):
            axis.text(
                x=col,
                y=row,
                s=count,
                va="center",
                ha="center",
                size=text_size,
            )

    first_fold_model = model_list[0]
    headline_metric = list(metrics_per_fold)[0]

    num_folds = len(model_list)
    cols = 3
    rows = int(math.ceil(num_folds / cols))

    fig = plt.figure(constrained_layout=True, figsize=(15, 6))
    subplots = fig.subfigures(1, 2)
    ax0 = subplots[0].subplots(1, 1)
    gs = gridspec.GridSpec(
        rows,
        cols,
        hspace=0.5,
        wspace=0.5,
    )

    # Per-fold matrices; every fold after the first shares its axes with the first.
    ax_og = None
    for n in range(num_folds):
        if ax_og is None:
            ax1 = subplots[1].add_subplot(gs[n])
            ax_og = ax1
        else:
            ax1 = subplots[1].add_subplot(gs[n], sharey=ax_og, sharex=ax_og)

        fold_matrix = confusion_matrix(
            y_true=val_reals[n], y_pred=val_preds[n].squeeze()
        )
        ax1.matshow(fold_matrix, cmap=plt.cm.RdPu, alpha=0.5)
        annotate_cells(ax1, fold_matrix, "large")

        ax1.set_xlabel("Predictions", fontsize=10)
        ax1.set_ylabel("Actuals", fontsize=10)
        # Look the metric up by the key exactly as stored: the name was read from
        # metrics_per_fold itself, so altering its case could only cause a KeyError.
        ax1.set_title(
            f"Fold {n + 1}:\n{headline_metric}={float(metrics_per_fold[headline_metric][n]):.3f}"
        )

    # Overall matrix over the concatenation of all folds.
    all_val_reals = torch.cat(val_reals, dim=-1)
    all_val_preds = torch.cat(val_preds, dim=-1)
    overall_matrix = confusion_matrix(
        y_true=all_val_reals, y_pred=all_val_preds.squeeze()
    )
    ax0.matshow(overall_matrix, cmap=plt.cm.RdPu, alpha=0.3)
    annotate_cells(ax0, overall_matrix, "xx-large")

    ax0.set_xlabel("Predictions", fontsize=18)
    ax0.set_ylabel("Actuals", fontsize=18)
    ax0.set_title(
        (
            f"{first_fold_model.model.method_name}: {headline_metric}"
            f"={float(overall_kfold_metrics[headline_metric]):.3f}"
        )
    )

    # Set the overall title for the entire figure
    fig.suptitle(f"{first_fold_model.model.__class__.__name__}: confusion matrix")
    return fig
[docs]
class ModelComparison(ParentPlotter):
"""
Plots the performance of multiple models on a single plot, either from the final validation data
(``from_final_val_data``) or from new data (``from_new_data``). Graph-based models are not supported by
``from_new_data``: as noted on 2023-10-11, it is not clear how to implement it for them.
"""
[docs]
def __init__(self):
    """Initialise the plotter by delegating to the parent class initialiser."""
    super().__init__()
[docs]
@classmethod
def from_final_val_data(cls, model_dict):
    """
    Plotting function for comparing models on metrics using the final validation data (i.e. the data that
    was used to evaluate the model when the model training was complete).

    Produces a violin plot for k-fold models and a bar plot for train/test models. The protocol is inferred
    from the first entry of ``model_dict``: a list of more than one model means k-fold.

    Parameters
    ----------
    model_dict: dict
        Dictionary of trained pytorch_lightning models.
        Keys are the names of the models and values are lists of the trained pytorch_lightning models.
        For k-fold, the lists must be of length > 1 (and the length of num_k).
        For train/test, the lists must be of length 1 (meaning there is only one model for each key).
        Results are keyed by each model's ``model.method_name``, not by the dictionary key.

    Returns
    -------
    fig: matplotlib.pyplot.figure
        The figure of the plot.
    df: pandas.DataFrame
        The dataframe of the metrics.

    Raises
    ------
    ValueError
        If ``model_dict`` is not a dict, is empty, contains an empty list, or mixes lists of length 1
        with lists of length > 1.
    """
    if not isinstance(model_dict, dict):
        raise ValueError(
            (
                "Argument 'model_dict' is not a dict. "
                "'model_dict' should have keys as the model names and values as lists of trained models. "
            )
        )

    # An empty dict would otherwise fail below with an uninformative IndexError.
    if not model_dict:
        raise ValueError(
            (
                "Argument 'model_dict' is an empty dict. "
                "'model_dict' should have keys as the model names and values as lists of trained models. "
            )
        )

    # check for empty lists in dict
    for model_list in model_dict.values():
        if len(model_list) == 0:
            raise ValueError(
                (
                    "Empty list of models has been passed into the ModelComparison.from_final_val_data. "
                    "There is an empty list somewhere in the model_dict. Please check the model_dict."
                )
            )

    # The first list of models decides the protocol: more than one model means k-fold.
    kfold = len(next(iter(model_dict.values()))) > 1
    comparing_models_metrics = {}

    if kfold:
        overall_kfold_metrics_dict = {}
        for model_list in model_dict.values():  # list of length k of trained models
            if len(model_list) == 1:
                raise ValueError(
                    (
                        "List of models in model_dict has length 1 but the kfold_flag is True. "
                        "K-fold training should produce a list of models of length > 1. "
                        "Please check the model_dict and the kfold_flag."
                    )
                )
            model_method_name = model_list[0].model.method_name
            # Only the metrics are needed here; the real/predicted values are discarded.
            (
                _,
                _,
                _,
                _,
                metrics_per_fold,
                overall_kfold_metrics,
            ) = cls.get_kfold_data_from_model(model_list)
            comparing_models_metrics[model_method_name] = metrics_per_fold
            overall_kfold_metrics_dict[model_method_name] = overall_kfold_metrics

        figure = cls.kfold_comparison_plot(comparing_models_metrics)
        df = cls.get_performance_dataframe(
            comparing_models_metrics, overall_kfold_metrics_dict, kfold
        )
    else:
        for model_list in model_dict.values():  # list of length 1 of trained model
            if len(model_list) > 1:
                raise ValueError(
                    (
                        "List of models in model_dict has length > 1 but the kfold_flag is False. "
                        "Train/test training should produce a list of models of length 1. "
                        "Please check the model_dict and the kfold_flag."
                    )
                )
            model_method_name = model_list[0].model.method_name
            # Only the metrics are needed here; the real/predicted values are discarded.
            (
                _,
                _,
                _,
                _,
                metric_values,
            ) = cls.get_tt_data_from_model(model_list)
            comparing_models_metrics[model_method_name] = metric_values

        figure = cls.train_test_comparison_plot(comparing_models_metrics)
        df = cls.get_performance_dataframe(
            comparing_models_metrics, None, kfold
        )

    return figure, df
[docs]
@classmethod
def from_new_data(cls, model_dict, output_paths, test_data_paths, checkpoint_file_suffix=None,
                  layer_mods=None):
    """
    Plotting function for comparing models on metrics using new data (i.e. data that was not used to train
    or validate the model).

    Produces a violin plot for k-fold models and a bar plot for train/test models. The protocol is inferred
    from the first entry of ``model_dict``: a list of more than one model means k-fold.

    Graph-based models (those whose ``model`` has a ``graph_maker`` attribute) are not supported: a
    ``UserWarning`` is issued for each of them and they are left out of the comparison.

    Parameters
    ----------
    model_dict: dict
        Dictionary of trained pytorch_lightning models.
        Keys are the names of the models and values are lists of the trained pytorch_lightning models.
        For k-fold, the lists must be of length > 1 (and the length of num_k).
        For train/test, the lists must be of length 1 (meaning there is only one model for each key).
        Results are keyed by each model's ``model.method_name``, not by the dictionary key.
    output_paths: dict
        Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.
    test_data_paths: dict
        Dictionary of the paths to the new data. The keys are the names of the data types (e.g. "tabular1", "image").
    checkpoint_file_suffix: str, optional
        Suffix that is on the trained model checkpoint files. e.g. "_firsttry". Added by the user.
        Default is None.
    layer_mods: dict, optional
        Dictionary of the layer modifications to make to the model.

    Returns
    -------
    fig: matplotlib.pyplot.figure
        The figure of the plot.
    df: pandas.DataFrame
        The dataframe of the metrics.

    Raises
    ------
    ValueError
        If ``model_dict`` is not a dict, is empty, contains an empty list, mixes lists of length 1 with
        lists of length > 1, or contains only graph-based models.
    """
    import warnings  # local import: only needed for the graph-model notice

    if not isinstance(model_dict, dict):
        raise ValueError(
            (
                "Argument 'model_dict' is not a dict. "
                "'model_dict' should have keys as the model names and values as lists of trained models. "
            )
        )

    # An empty dict would otherwise fail below with an uninformative IndexError.
    if not model_dict:
        raise ValueError(
            (
                "Argument 'model_dict' is an empty dict. "
                "'model_dict' should have keys as the model names and values as lists of trained models. "
            )
        )

    # check for empty lists in dict
    for model_list in model_dict.values():
        if len(model_list) == 0:
            raise ValueError(
                (
                    "Empty list of models has been passed into the ModelComparison.from_new_data. "
                    "There is an empty list somewhere in the model_dict. Please check the model_dict."
                )
            )

    # The first list of models decides the protocol: more than one model means k-fold.
    kfold = len(next(iter(model_dict.values()))) > 1

    comparing_models_metrics = {}
    # Only k-fold runs have overall metrics; train/test passes None to the dataframe builder.
    overall_kfold_metrics_dict = {} if kfold else None

    for model_list in model_dict.values():
        # Graph-based models are skipped with a warning. Raising here would abort the
        # whole comparison and make the skip unreachable.
        if hasattr(model_list[0].model, "graph_maker"):
            warnings.warn(
                (
                    "Graph-based models are not currently supported for the ModelComparison.from_new_data. "
                    "The graph-based models will be skipped."
                ),
                UserWarning,
                stacklevel=2,
            )
            continue

        if kfold:
            if len(model_list) == 1:
                raise ValueError(
                    (
                        "List of models in model_dict has length 1 but the kfold_flag is True. "
                        "K-fold training should produce a list of models of length > 1. "
                        "Please check the model_dict and the kfold_flag."
                    )
                )
            model_method_name = model_list[0].model.method_name
            # Only the metrics are needed here; the real/predicted values are discarded.
            (
                _,
                _,
                _,
                _,
                metrics_per_fold,
                overall_kfold_metrics,
            ) = cls.get_new_kfold_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix,
                                       layer_mods)
            comparing_models_metrics[model_method_name] = metrics_per_fold
            overall_kfold_metrics_dict[model_method_name] = overall_kfold_metrics
        else:
            if len(model_list) > 1:
                raise ValueError(
                    (
                        "List of models in model_dict has length > 1 but the kfold_flag is False. "
                        "Train/test training should produce a list of models of length 1. "
                        "Please check the model_dict and the kfold_flag."
                    )
                )
            model_method_name = model_list[0].model.method_name
            # Only the metrics are needed here; the real/predicted values are discarded.
            (
                _,
                _,
                _,
                _,
                metric_values,
            ) = cls.get_new_tt_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix, layer_mods)
            comparing_models_metrics[model_method_name] = metric_values

    # If every model was skipped, the plotting helpers would fail on an empty dict.
    if not comparing_models_metrics:
        raise ValueError(
            (
                "All models in model_dict are graph-based, which are not supported by "
                "ModelComparison.from_new_data, so there is nothing left to compare."
            )
        )

    if kfold:
        figure = cls.kfold_comparison_plot(comparing_models_metrics)
    else:
        figure = cls.train_test_comparison_plot(comparing_models_metrics)
    df = cls.get_performance_dataframe(
        comparing_models_metrics, overall_kfold_metrics_dict, kfold
    )

    return figure, df
[docs]
@classmethod
def kfold_comparison_plot(cls, comparing_models_metrics):
    """Plotting function for comparing models on kfold metrics. Produces a violin plot.

    Two side-by-side panels are drawn, one per metric (the first two metrics of the first
    method). Methods are ordered by the mean of their first metric across folds.

    Parameters
    ----------
    comparing_models_metrics: dict
        Dictionary of metrics. Keys are the model names and values are dicts
        {metric1name: per-fold values, metric2name: per-fold values}.

    Returns
    -------
    fig: matplotlib figure
        Figure containing the violin plots.
    """

    def style_violins(parts, face_colour):
        # Fill the violin bodies and outline every summary line in black.
        for body in parts["bodies"]:
            body.set_facecolor(face_colour)
            body.set_edgecolor("black")
            body.set_alpha(0.5)
        for line_key in ("cmeans", "cmins", "cmaxes", "cbars"):
            parts[line_key].set_edgecolor("black")

    models = list(comparing_models_metrics)  # [method1name, method2name,...]
    metric_keys = list(comparing_models_metrics[models[0]])
    first_metric = metric_keys[0]
    second_metric = metric_keys[1]

    # One row per method, one column per fold.
    first_scores = np.array(
        [comparing_models_metrics[name][first_metric] for name in models]
    )
    second_scores = np.array(
        [comparing_models_metrics[name][second_metric] for name in models]
    )

    # Sort methods by the mean of the first metric over folds, then transpose so that
    # each column passed to violinplot is one method.
    order = np.argsort(first_scores.mean(axis=1))
    models = np.array(models)[order]
    first_scores = first_scores[order].transpose()
    second_scores = second_scores[order].transpose()

    fig, (left_ax, right_ax) = plt.subplots(1, 2)
    left_ax.grid()
    right_ax.grid()

    # Left panel: first metric, with the method names as tick labels.
    left_violins = left_ax.violinplot(first_scores, vert=False, showmeans=True)
    style_violins(left_violins, "violet")
    left_ax.yaxis.set_ticks(np.arange(len(models)) + 1)
    left_ax.set_yticklabels(models)
    left_ax.get_xaxis().tick_bottom()
    left_ax.set_xlim(right=1.0)

    # Right panel: second metric, same tick positions but no labels.
    right_violins = right_ax.violinplot(second_scores, vert=False, showmeans=True)
    style_violins(right_violins, "powderblue")
    right_ax.yaxis.set_ticks(np.arange(len(models)) + 1)
    right_ax.set_yticklabels([])
    right_ax.get_xaxis().tick_bottom()

    # set titles and limits
    left_ax.set_title(first_metric)
    right_ax.set_title(second_metric)
    right_ax.set_xlim(left=0.0)

    fig.suptitle("Distribution of metrics between cross-validation folds")
    fig.tight_layout()
    return fig
[docs]
@classmethod
def train_test_comparison_plot(cls, comparing_models_metrics):
    """Plotting function for comparing models on train and test metrics. Produces a horizontal bar chart.

    Two side-by-side panels are drawn, one per metric (the first two metrics of the first
    method). Methods are ordered by their first metric.

    Parameters
    ----------
    comparing_models_metrics: dict
        Dictionary of metrics. Keys are the model names and values are dict
        {metric1name: metric_value, metric2name: metric_value}

    Returns
    -------
    fig: matplotlib figure
        Figure containing the horizontal bar chart.
    """
    models = list(comparing_models_metrics)  # [method1name, method2name,...]
    metric_keys = list(comparing_models_metrics[models[0]])
    first_metric = metric_keys[0]
    second_metric = metric_keys[1]

    first_scores = [comparing_models_metrics[name][first_metric] for name in models]
    second_scores = [comparing_models_metrics[name][second_metric] for name in models]

    # Order every series by the first metric.
    order = np.argsort(first_scores)
    models = np.array(models)[order]
    first_scores = np.array(first_scores)[order]
    second_scores = np.array(second_scores)[order]

    bar_positions = np.arange(len(models))
    bar_width = 0.35

    fig, (left_ax, right_ax) = plt.subplots(1, 2)
    left_ax.grid()
    right_ax.grid()

    # Left panel: first metric, with the method names as tick labels.
    left_ax.barh(
        bar_positions, first_scores, bar_width, color="violet", edgecolor="purple"
    )
    # black dashed line at x=0
    left_ax.axvline(x=0, color="black", linestyle="--", alpha=0.5)
    left_ax.yaxis.set_ticks(np.arange(len(models)))
    left_ax.set_yticklabels(models)
    left_ax.get_xaxis().tick_bottom()
    left_ax.set_xlim(right=1.0)

    # Right panel: second metric, same tick positions but no labels.
    right_ax.barh(
        bar_positions,
        second_scores,
        bar_width,
        color="powderblue",
        edgecolor="steelblue",
    )
    right_ax.yaxis.set_ticks(np.arange(len(models)))
    right_ax.set_yticklabels([])
    right_ax.get_xaxis().tick_bottom()

    # set titles and limits
    left_ax.set_title(first_metric)
    right_ax.set_title(second_metric)
    right_ax.set_xlim(left=0.0)

    fig.suptitle("Model Performance Comparison")
    fig.tight_layout()
    return fig