fusilli.eval

This module contains classes and functions for evaluating trained models, i.e. plotting results from training. Its design is inspired by the scikit-learn API for plotting results: each plot is a class with a from_final_val_data method, which takes a trained model and returns a plot of the validation data, and a from_new_data method, which takes a trained model and new data and returns a plot of that data.
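The two-constructor pattern described above can be sketched in plain Python. This is a minimal illustration only: SketchModel and SketchPlot are hypothetical stand-ins, not fusilli classes, the from_new_data signature is simplified (the real methods take output_paths and test_data_paths), and a dictionary is returned where a real plot class would return a matplotlib figure.

```python
class SketchModel:
    """Hypothetical trained model that stores its final validation data."""

    def __init__(self, weight, val_inputs, val_reals):
        self.weight = weight
        self.val_reals = list(val_reals)
        # Predictions on the validation set are kept on the model instance.
        self.val_preds = [self.predict(x) for x in val_inputs]

    def predict(self, x):
        return self.weight * x


class SketchPlot:
    """Hypothetical plot class with the two alternative constructors."""

    @classmethod
    def from_final_val_data(cls, model_list):
        # Reuse the validation data already stored on the trained model.
        model = model_list[0]
        return cls._describe(model.val_reals, model.val_preds, "final validation data")

    @classmethod
    def from_new_data(cls, model_list, new_inputs, new_reals):
        # Run unseen inputs through the trained model, then plot the same way.
        model = model_list[0]
        new_preds = [model.predict(x) for x in new_inputs]
        return cls._describe(list(new_reals), new_preds, "new data")

    @staticmethod
    def _describe(reals, preds, source):
        # A real plot class would draw and return a figure here.
        return {"source": source, "points": list(zip(reals, preds))}


model = SketchModel(weight=2, val_inputs=[1, 2], val_reals=[2, 5])
old = SketchPlot.from_final_val_data([model])
new = SketchPlot.from_new_data([model], new_inputs=[3], new_reals=[7])
```

Both constructors funnel into the same drawing step, so the only difference between them is where the real and predicted values come from.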

Classes

ConfusionMatrix()

Plots the confusion matrix for a model.

ModelComparison()

Plots the performance of multiple models on a single plot.

ParentPlotter()

Parent class for all plot classes.

RealsVsPreds()

Plots the real values vs the predicted values for a model.

class ConfusionMatrix[source]

Bases: ParentPlotter

Plots the confusion matrix for a model. This should be used for classification models only (binary or multiclass). The data used to create the confusion matrix is either new data if using from_new_data or the original validation data if using from_final_val_data.
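As a rough illustration of the quantity being plotted, the sketch below counts a confusion matrix from paired real and predicted class labels. It is not fusilli's implementation: the function name is hypothetical, plain lists of integer labels stand in for torch.Tensors, and the rows-are-real, columns-are-predicted layout is an assumed convention.

```python
def confusion_matrix(reals, preds, num_classes):
    """Count matrix[real][pred] over paired integer class labels."""
    if len(reals) != len(preds):
        raise ValueError("reals and preds must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for real, pred in zip(reals, preds):
        matrix[real][pred] += 1
    return matrix


reals = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]
matrix = confusion_matrix(reals, preds, num_classes=3)
# The diagonal entries count the correctly classified samples of each class.
```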

__init__()[source]
classmethod confusion_matrix_kfold(model_list, val_reals, val_preds, metrics_per_fold, overall_kfold_metrics)[source]

Confusion matrix for a kfold model. This function should be called within the ConfusionMatrix class after the k-fold data has been obtained from the model (either old data or new data).

Parameters:
  • model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc.

  • val_reals (list) – List of torch.Tensors of the real values for the new data set for each fold.

  • val_preds (list) – List of torch.Tensors of the predicted values for the new data set for each fold.

  • metrics_per_fold (dict) – Dictionary of the metrics for each fold. The keys are the names of the metrics and the values are lists of the metric values for each fold.

  • overall_kfold_metrics (dict) – Dictionary of the overall kfold metrics. The keys are the names of the metrics and the values are the metric values for the overall kfold, meaning the metric values for the concatenated new data over all the folds.

Returns:

fig – The figure of the plot.

Return type:

matplotlib.pyplot.figure
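The difference between metrics_per_fold and overall_kfold_metrics can be sketched as follows. This is an illustration under assumptions, not fusilli's code: accuracy is used as an example metric, plain lists stand in for torch.Tensors, and the helper names are hypothetical.

```python
def accuracy(reals, preds):
    """Fraction of positions where the predicted label equals the real label."""
    correct = sum(1 for r, p in zip(reals, preds) if r == p)
    return correct / len(reals)


def kfold_metrics(val_reals, val_preds):
    """Return (metrics_per_fold, overall_kfold_metrics) for lists of folds."""
    # One metric value per fold, stored as a list under the metric name.
    metrics_per_fold = {
        "accuracy": [accuracy(r, p) for r, p in zip(val_reals, val_preds)]
    }
    # The overall metric is computed once on the concatenated folds.
    all_reals = [x for fold in val_reals for x in fold]
    all_preds = [x for fold in val_preds for x in fold]
    overall_kfold_metrics = {"accuracy": accuracy(all_reals, all_preds)}
    return metrics_per_fold, overall_kfold_metrics


val_reals = [[0, 1, 1, 0], [1, 0]]
val_preds = [[0, 1, 0, 0], [1, 1]]
per_fold, overall = kfold_metrics(val_reals, val_preds)
```

When folds differ in size, the overall value computed on the concatenated data is generally not the mean of the per-fold values, which is why the two dictionaries are kept separate.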

classmethod confusion_matrix_tt(model_list, val_reals, val_preds, metric_values)[source]

Confusion matrix for a train/test model. This function should be called within the ConfusionMatrix class after the train/test data has been obtained from the model (either old data or new data).

Parameters:
  • model_list (list) – A list of length 1 containing the trained pytorch_lightning model.

  • val_reals (torch.Tensor) – torch.Tensor of the real values for the new data set.

  • val_preds (torch.Tensor) – torch.Tensor of the predicted values for the new data set.

  • metric_values (dict) – Dictionary of the metrics for the model. The keys are the names of the metrics and the values are the metric values for the model.

Returns:

fig – The figure of the plot.

Return type:

matplotlib.pyplot.figure

classmethod from_final_val_data(model_list)[source]

Confusion matrix using the final validation data (i.e. the data that was used to evaluate the model when the model training was complete).

Parameters:

model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc. For train/test models, this is a list of length 1.

Returns:

figure – The figure of the plot.

Return type:

matplotlib.pyplot.figure

Raises:

ValueError – If model_list is not a list, is an empty list, has length > 1 while kfold_flag is False, or has length 1 while kfold_flag is True.
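The ValueError conditions listed above can be sketched as a standalone check. This is a hypothetical helper, not fusilli's code: how the library obtains kfold_flag is not stated here, so it is passed in explicitly as an assumption, and the error messages are illustrative.

```python
def check_model_list(model_list, kfold_flag):
    """Validate model_list against kfold_flag and return the protocol name."""
    if not isinstance(model_list, list):
        raise ValueError("model_list must be a list of trained models")
    if len(model_list) == 0:
        raise ValueError("model_list is empty")
    if len(model_list) > 1 and not kfold_flag:
        raise ValueError("more than one model given but kfold_flag is False")
    if len(model_list) == 1 and kfold_flag:
        raise ValueError("only one model given but kfold_flag is True")
    return "kfold" if kfold_flag else "train/test"


# Strings stand in for trained models in this sketch.
protocol = check_model_list(["fold1_model", "fold2_model"], kfold_flag=True)
```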

classmethod from_new_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None)[source]

Confusion matrix using new data (i.e. data that was not used to train or validate the model).

Parameters:
  • model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc. For train/test models, this is a list of length 1.

  • output_paths (dict) – Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.

  • test_data_paths (dict) – Dictionary of the paths to the new data. The keys are the names of the data types (e.g. “tabular1”, “image”).

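  • checkpoint_file_suffix (str, optional) – Suffix that is on the trained model checkpoint files, e.g. “_firsttry”. Added by the user. Default is None.

  • layer_mods (dict, optional) – Dictionary of the layer modifications to make to the model.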

Returns:

figure – The figure of the plot.

Return type:

matplotlib.pyplot.figure

Raises:

ValueError – If model_list is not a list, is an empty list, has length > 1 while kfold_flag is False, or has length 1 while kfold_flag is True.

class ModelComparison[source]

Bases: ParentPlotter

Plots the performance of multiple models on a single plot. As of 2023-10-11, this is only implemented for from_final_val_data, because it is not clear how to implement from_new_data for graph-based models; they would have to be left out of the main plot.

__init__()[source]
classmethod from_final_val_data(model_dict)[source]

Plotting function for comparing models on metrics using the final validation data (i.e. the data that was used to evaluate the model when the model training was complete). Produces a violin plot if kfold_flag is True and a bar plot if kfold_flag is False.

Parameters:

model_dict (dict) – Dictionary of trained pytorch_lightning models. Keys are the names of the models and values are lists of the trained pytorch_lightning models. If kfold_flag is True, the lists must have length > 1 (equal to num_k). If kfold_flag is False, the lists must have length 1 (meaning there is only one model for each key).

Returns:

  • fig (matplotlib.pyplot.figure) – The figure of the plot.

  • df (pandas.DataFrame) – The dataframe of the metrics.

classmethod from_new_data(model_dict, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None)[source]

Plotting function for comparing models on metrics using new data (i.e. data that was not used to train or validate the model). Produces a violin plot if kfold_flag is True and a bar plot if kfold_flag is False.

Parameters:
  • model_dict (dict) – Dictionary of trained pytorch_lightning models. Keys are the names of the models and values are lists of the trained pytorch_lightning models. If kfold_flag is True, the lists must have length > 1 (equal to num_k). If kfold_flag is False, the lists must have length 1 (meaning there is only one model for each key).

  • output_paths (dict) – Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.

  • test_data_paths (dict) – Dictionary of the paths to the new data. The keys are the names of the data types (e.g. “tabular1”, “image”).

  • checkpoint_file_suffix (str, optional) – Suffix that is on the trained model checkpoint files. e.g. “_firsttry”. Added by the user. Default is None.

  • layer_mods (dict, optional) – Dictionary of the layer modifications to make to the model.

Returns:

  • fig (matplotlib.pyplot.figure) – The figure of the plot.

  • df (pandas.DataFrame) – The dataframe of the metrics.

classmethod get_performance_dataframe(comparing_models_metrics, overall_kfold_metrics_dict, kfold_flag)[source]

Get a dataframe of the performance metrics for each model. For kfold models, the dataframe contains the overall kfold metrics and the metrics for each fold. For train/test models, the dataframe contains the metrics for the train and test sets.

Parameters:
  • comparing_models_metrics (dict) – Dictionary of metrics. Keys are the model names and values are dict {metric1name: metric_value, metric2name: metric_value}. Metric values can be float or int.

  • overall_kfold_metrics_dict (dict) – Dictionary of overall kfold metrics. Keys are the model names and values are dict {metric1name: metric_value, metric2name: metric_value}. This is only needed for kfold models.

  • kfold_flag (bool) – True if the model is a kfold model, False if the model is a train/test model.

Returns:

df – Dataframe of the performance metrics for each model.

Return type:

pandas.DataFrame
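One possible way to flatten these dictionaries into table rows is sketched below. It is an assumption-laden illustration rather than fusilli's implementation: rows are plain dictionaries instead of a pandas.DataFrame, the column names (overall_..., fold1_...) are hypothetical, and for kfold models the per-fold metric values are assumed to be lists with one entry per fold.

```python
def performance_rows(comparing_models_metrics, overall_kfold_metrics_dict=None,
                     kfold_flag=False):
    """Build one row per model from the metric dictionaries."""
    rows = []
    for model_name, metrics in comparing_models_metrics.items():
        row = {"model": model_name}
        if kfold_flag:
            # Overall metrics first, then one column per fold and metric.
            for metric, value in overall_kfold_metrics_dict[model_name].items():
                row[f"overall_{metric}"] = value
            for metric, fold_values in metrics.items():
                for fold_number, value in enumerate(fold_values, start=1):
                    row[f"fold{fold_number}_{metric}"] = value
        else:
            # Train/test models have a single value per metric.
            row.update(metrics)
        rows.append(row)
    return rows


tt_rows = performance_rows({"ModelA": {"R2": 0.8, "MAE": 1.5},
                            "ModelB": {"R2": 0.7, "MAE": 2.0}})
kfold_rows = performance_rows({"ModelA": {"R2": [0.8, 0.6]}},
                              {"ModelA": {"R2": 0.7}}, kfold_flag=True)
```

A list of row dictionaries like this can be passed directly to the pandas.DataFrame constructor if a dataframe is needed.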

classmethod kfold_comparison_plot(comparing_models_metrics)[source]

Plotting function for comparing models on kfold metrics. Produces a violin plot.

Parameters:

comparing_models_metrics (dict) – Dictionary of metrics. Keys are the model names and values are dict {metric1name: metric_value, metric2name: metric_value}

Returns:

fig – Figure containing the violin plots.

Return type:

matplotlib.pyplot.figure

classmethod train_test_comparison_plot(comparing_models_metrics)[source]

Plotting function for comparing models on train and test metrics. Produces a horizontal bar chart.

Parameters:

comparing_models_metrics (dict) – Dictionary of metrics. Keys are the model names and values are dict {metric1name: metric_value, metric2name: metric_value}

Returns:

fig – Figure containing the horizontal bar chart.

Return type:

matplotlib.pyplot.figure

class ParentPlotter[source]

Bases: object

Parent class for all plot classes.

It includes methods that are used by multiple plot classes, such as obtaining final validation data from kfold and train/test models, and putting new data through the models for both kfold and train/test protocols.

__init__()[source]
classmethod get_kfold_data_from_model(model_list)[source]

Get the final validation data from a kfold model, meaning the data that was used to evaluate the model when the model training was complete.

Parameters:

model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc.

Returns:

  • train_reals (list) – List of torch.Tensors of the real values for the training set for each fold. This is stored in each trained model’s class instance.

  • train_preds (list) – List of torch.Tensors of the predicted values for the training set for each fold. This is stored in each trained model’s class instance.

  • val_reals (list) – List of torch.Tensors of the real values for the validation set for each fold. This is stored in each trained model’s class instance.

  • val_preds (list) – List of torch.Tensors of the predicted values for the validation set for each fold. This is stored in each trained model’s class instance.

  • metrics_per_fold (dict) – Dictionary of the metrics for each fold. The keys are the names of the metrics and the values are lists of the metric values for each fold.

  • overall_kfold_metrics (dict) – Dictionary of the overall kfold metrics. The keys are the names of the metrics and the values are the metric values for the overall kfold, meaning the metric values for the concatenated final validation data over all the folds.

classmethod get_new_kfold_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None)[source]

Get results for new data by running it through the trained models of a kfold model.

Parameters:
  • model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc.

  • output_paths (dict) – Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.

  • test_data_paths (dict) – Dictionary of the paths to the new data. The keys are the names of the data types (e.g. “tabular1”, “image”).

  • checkpoint_file_suffix (str, optional) – Suffix that is on the trained model checkpoint files. e.g. “_firsttry”. Added by the user. Default is None.

  • layer_mods (dict, optional) – Dictionary of the layer modifications to make to the model.

Returns:

  • train_reals (list) – List of torch.Tensors of the real values for the training set for each fold. This is stored in each trained model’s class instance.

  • train_preds (list) – List of torch.Tensors of the predicted values for the training set for each fold. This is stored in each trained model’s class instance.

  • val_reals (list) – List of torch.Tensors of the real values for the new data set for each fold.

  • val_preds (list) – List of torch.Tensors of the predicted values for the new data set for each fold. This was obtained by running the new data through the trained model.

  • metrics_per_fold (dict) – Dictionary of the metrics for each fold. The keys are the names of the metrics and the values are lists of the metric values for each fold.

  • overall_kfold_metrics (dict) – Dictionary of the overall kfold metrics. The keys are the names of the metrics and the values are the metric values for the overall kfold, meaning the metric values for the concatenated new data over all the folds.

Raises:

ValueError – If the model has a graph maker. Creating graphs from new data is not supported yet.

classmethod get_new_tt_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None)[source]

Get results for new data by running it through the trained model of a train/test model.

Parameters:
  • model_list (list) – A list of length 1 containing the trained pytorch_lightning model.

  • output_paths (dict) – Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.

  • test_data_paths (dict) – Dictionary of the paths to the new data. The keys are the names of the data types (e.g. “tabular1”, “image”).

  • checkpoint_file_suffix (str, optional) – Suffix that is on the trained model checkpoint files. e.g. “_firsttry”. Added by the user. Default is None.

  • layer_mods (dict, optional) – Dictionary of the layer modifications to make to the model.

Returns:

  • train_reals (torch.Tensor) – torch.Tensor of the real values for the training set. This is stored in the trained model’s class instance.

  • train_preds (torch.Tensor) – torch.Tensor of the predicted values for the training set. This is stored in the trained model’s class instance.

  • val_reals (torch.Tensor) – torch.Tensor of the real values for the new data set.

  • val_preds (torch.Tensor) – torch.Tensor of the predicted values for the new data set.

  • metric_values (dict) – Dictionary of the metrics for the model. The keys are the names of the metrics and the values are the metric values for the model.

Raises:

ValueError – If the model has a graph maker. Creating graphs from new data is not supported yet.

classmethod get_tt_data_from_model(model_list)[source]

Get the final validation data from a train/test model, meaning the data that was used to evaluate the model when the model training was complete.

Parameters:

model_list (list) – A list of length 1 containing the trained pytorch_lightning model.

Returns:

  • train_reals (torch.Tensor) – torch.Tensor of the real values for the training set. This is stored in the trained model’s class instance.

  • train_preds (torch.Tensor) – torch.Tensor of the predicted values for the training set. This is stored in the trained model’s class instance.

  • val_reals (torch.Tensor) – torch.Tensor of the real values for the validation set. This is stored in the trained model’s class instance.

  • val_preds (torch.Tensor) – torch.Tensor of the predicted values for the validation set. This is stored in the trained model’s class instance.

  • metric_values (dict) – Dictionary of the metrics for the model. The keys are the names of the metrics and the values are the metric values for the model.

class RealsVsPreds[source]

Bases: ParentPlotter

Plots the real values vs the predicted values for a model. Pink dots are the training data and green dots are the validation data. The validation data is either new data if using from_new_data or the original validation data if using from_final_val_data.

__init__()[source]
classmethod from_final_val_data(model_list)[source]

Reals vs preds plot using the final validation data (i.e. the data that was used to evaluate the model when the model training was complete).

Parameters:

model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc. For train/test models, this is a list of length 1.

Returns:

figure – The figure of the plot.

Return type:

matplotlib.pyplot.figure

Raises:

ValueError – If model_list is not a list, is an empty list, has length > 1 while kfold_flag is False, or has length 1 while kfold_flag is True.

classmethod from_new_data(model_list, output_paths, test_data_paths, checkpoint_file_suffix=None, layer_mods=None)[source]

Reals vs preds plot using new data (i.e. data that was not used to train or validate the model).

Parameters:
  • model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc. For train/test models, this is a list of length 1.

  • output_paths (dict) – Dictionary of the output paths. Used for knowing where the checkpoint files are stored and where to save the plots.

  • test_data_paths (dict) – Dictionary of the paths to the new data. The keys are the names of the data types (e.g. “tabular1”, “image”).

  • checkpoint_file_suffix (str, optional) – Suffix that is on the trained model checkpoint files. e.g. “_firsttry”. Added by the user. Default is None.

  • layer_mods (dict, optional) – Dictionary of the layer modifications to make to the model.

Returns:

figure – The figure of the plot.

Return type:

matplotlib.pyplot.figure

Raises:

ValueError – If model_list is not a list, is an empty list, has length > 1 while kfold_flag is False, or has length 1 while kfold_flag is True.

classmethod reals_vs_preds_kfold(model_list, val_reals, val_preds, metrics_per_fold, overall_kfold_metrics)[source]

Reals vs preds plot for a kfold model. This function should be called within the RealsVsPreds class after the k-fold data has been obtained from the model (either old data or new data).

Parameters:
  • model_list (list) – List of trained pytorch_lightning models. For kfold models, this is a list of at least length 2, where the first element is the k=1 model and the second element is the k=2 model, etc.

  • val_reals (list) – List of torch.Tensors of the real values for the new data set for each fold.

  • val_preds (list) – List of torch.Tensors of the predicted values for the new data set for each fold.

  • metrics_per_fold (dict) – Dictionary of the metrics for each fold. The keys are the names of the metrics and the values are lists of the metric values for each fold.

  • overall_kfold_metrics (dict) – Dictionary of the overall kfold metrics. The keys are the names of the metrics and the values are the metric values for the overall kfold, meaning the metric values for the concatenated new data over all the folds.

Returns:

fig – The figure of the plot.

Return type:

matplotlib.pyplot.figure

classmethod reals_vs_preds_tt(model_list, train_reals, train_preds, val_reals, val_preds, metric_values)[source]

Reals vs preds plot for a train/test model. This function should be called within the RealsVsPreds class after the train/test data has been obtained from the model (either old data or new data).

Parameters:
  • model_list (list) – A list of length 1 containing the trained pytorch_lightning model.

  • train_reals (torch.Tensor) – torch.Tensor of the real values for the training set.

  • train_preds (torch.Tensor) – torch.Tensor of the predicted values for the training set.

  • val_reals (torch.Tensor) – torch.Tensor of the real values for the new data set.

  • val_preds (torch.Tensor) – torch.Tensor of the predicted values for the new data set.

  • metric_values (dict) – Dictionary of the metrics for the model. The keys are the names of the metrics and the values are the metric values for the model.

Returns:

fig – The figure of the plot.

Return type:

matplotlib.pyplot.figure