.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_comparison/plot_model_comparison_loop_kfold.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_model_comparison_plot_model_comparison_loop_kfold.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_comparison_plot_model_comparison_loop_kfold.py:


Comparing All Fusion Models
====================================================================

Welcome to the "Comparing All Fusion Models" tutorial! In this tutorial, we'll explore how to train and compare multiple fusion models for a multiclass classification task using k-fold cross-validation with multimodal tabular data. This tutorial is designed to help you understand and implement key features, including:

- 📥 Importing fusion models based on modality types.
- 🚲 Setting training parameters for your models.
- 🔮 Specifying the data to be used for training and testing.
- 🧪 Training and evaluating multiple fusion models.
- 📈 Visualising the results of individual models.
- 📊 Comparing the performance of different models.

Let's dive into each of these steps in detail:

1. **Importing Fusion Models:** We begin by importing fusion models based on modality types. These models will be used in our multiclass classification task, letting us explore various fusion strategies.

2. **Setting the Training Parameters:** To ensure consistent and controlled training, we define the training parameters. These include enabling k-fold cross-validation, specifying the prediction type, and setting the batch size for training.

3. **Specifying the Data:** Next, we specify the data to be used for training and testing. We use the MNIST dataset, split into top and bottom halves that serve as our two tabular modalities.

4. **Training All Fusion Models:** Now we train all the selected fusion models using the data and the defined training parameters, monitoring each model's performance during training and storing the results for later analysis.

5. **Plotting Individual Model Results:** After training, we visualise the performance of each individual model with plots of its loss curves and performance metrics.

6. **Comparing Model Performance:** To see which fusion models perform best, we compare them using a violin chart, which gives a clear overview of how each model's performance metrics compare.

7. **Saving the Results:** Finally, we save the performance results of all the models as a structured DataFrame. This data can be analysed further, exported to a CSV file, or used for additional experiments.

Now, let's walk through each of these steps in code. Let's get started! 🌸

.. GENERATED FROM PYTHON SOURCE LINES 46-60

.. code-block:: Python


    import matplotlib.pyplot as plt
    from tqdm.auto import tqdm
    import os

    from fusilli.data import prepare_fusion_data
    from fusilli.eval import ConfusionMatrix, ModelComparison
    from fusilli.train import train_and_save_models
    from fusilli.utils.model_chooser import import_chosen_fusion_models

    # sphinx_gallery_thumbnail_number = -1

    # from IPython.utils import io  # for hiding the tqdm progress bar


.. GENERATED FROM PYTHON SOURCE LINES 61-72

1. Import fusion models 🔍
-----------------------------

Here we import the fusion models to be compared. The models are imported using the :func:`~fusilli.utils.model_chooser.get_models` function, which takes a dictionary of conditions as input. The conditions are attributes of the models, e.g. the class name, the modality type, etc. The function returns a dataframe of the models that match the conditions.
The dataframe contains the method name, the class name, the modality type, the fusion type, the path to the model, and the path to the model's parent class. The paths are used to import the models with :func:`importlib.import_module`.

For this example, we import all the fusion models that use only tabular data (either uni-modal or multi-modal).

.. GENERATED FROM PYTHON SOURCE LINES 72-79

.. code-block:: Python


    model_conditions = {
        "modality_type": ["tabular1", "tabular2", "tabular_tabular"],
    }

    fusion_models = import_chosen_fusion_models(model_conditions)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Imported methods:
    ['Tabular1 uni-modal' 'Tabular2 uni-modal'
     'Concatenating tabular feature maps' 'Concatenating tabular data'
     'Channel-wise multiplication net (tabular)'
     'Tabular Crossmodal multi-head attention' 'Tabular decision'
     'MCVAE Tabular' 'Edge Correlation GNN' 'Attention-weighted GNN'
     'Activation function and tabular self-attention'
     'Activation function map fusion']




.. GENERATED FROM PYTHON SOURCE LINES 80-98

2. Set the training parameters 🎯
-----------------------------------

Let's configure our training parameters. For training and testing, the necessary parameters are:

- Paths to the input data files.
- Paths to the output directories.
- ``prediction_task``: the type of prediction to be performed. This is either ``regression``, ``binary``, or ``multiclass``.

Some optional parameters are:

- ``kfold``: a boolean of whether to use k-fold cross-validation (True) or not (False). By default, this is set to False.
- ``num_folds``: the number of folds to use. It can't be 1, because k-fold cross-validation needs at least two folds.
- ``wandb_logging``: a boolean of whether to log the results using Weights and Biases (True) or not (False). Default is False.
- ``test_size``: the proportion of the dataset to include in the test split. Default is 0.2.
- ``batch_size``: the batch size to use for training. Default is 8.
- ``multiclass_dimensions``: the number of classes to use for multiclass classification. Default is None; it must be set when ``prediction_task`` is ``multiclass``.
- ``max_epochs``: the maximum number of epochs to train for. Default is 1000.

.. GENERATED FROM PYTHON SOURCE LINES 98-129

.. code-block:: Python


    # Multiclass classification task (predicting one of 10 classes)
    prediction_task = "multiclass"
    number_of_classes = 10

    # Set the batch size
    batch_size = 32

    # Enable k-fold cross-validation with k=3
    kfold = True
    num_folds = 3

    # Setting output directories
    output_paths = {
        "losses": "loss_logs/model_comparison_loop_kfold",
        "checkpoints": "checkpoints/model_comparison_loop_kfold",
        "figures": "figures/model_comparison_loop_kfold",
    }

    os.makedirs(output_paths["losses"], exist_ok=True)
    os.makedirs(output_paths["checkpoints"], exist_ok=True)
    os.makedirs(output_paths["figures"], exist_ok=True)

    # Clearing the loss logs directory (only for the example notebooks)
    for run_dir in os.listdir(output_paths["losses"]):
        run_path = os.path.join(output_paths["losses"], run_dir)
        # remove the files inside each run directory
        for file_name in os.listdir(run_path):
            os.remove(os.path.join(run_path, file_name))
        # then remove the (now empty) directory itself
        os.rmdir(run_path)








.. GENERATED FROM PYTHON SOURCE LINES 130-133

3. Specifying input file paths 🔮
-----------------------------------

We're using the MNIST dataset for this example, and the CSV files are stored in the ``_static/mnist_data`` directory alongside the documentation files.

.. GENERATED FROM PYTHON SOURCE LINES 133-140

.. code-block:: Python


    data_paths = {
        "tabular1": "../../_static/mnist_data/mnist1.csv",
        "tabular2": "../../_static/mnist_data/mnist2.csv",
        "image": "",  # no image modality is used in this example
    }








.. GENERATED FROM PYTHON SOURCE LINES 141-145

4. Training all the fusion models 🏁
--------------------------------------

In this section, we train all the fusion models using the data and the specified parameters. We store the results of each model for later analysis.

.. GENERATED FROM PYTHON SOURCE LINES 145-177

.. code-block:: Python


    # In a notebook, %%capture can hide the progress bars and plots (there are a lot of them!)

    all_trained_models = {}

    for fusion_model in fusion_models:
        fusion_model_name = fusion_model.__name__
        print(f"Running model {fusion_model_name}")

        # Get data module
        data_module = prepare_fusion_data(prediction_task=prediction_task,
                                          multiclass_dimensions=number_of_classes,
                                          fusion_model=fusion_model,
                                          data_paths=data_paths,
                                          output_paths=output_paths,
                                          kfold=kfold,
                                          num_folds=num_folds,
                                          batch_size=batch_size)

        # Train and test
        single_model_list = train_and_save_models(
            data_module=data_module,
            fusion_model=fusion_model,
            enable_checkpointing=False,  # We're not saving the trained models for this example
            show_loss_plot=True,  # We'll show the loss plot for each model instead of saving it
        )

        # Save to all_trained_models
        all_trained_models[fusion_model_name] = single_model_list

        plt.close("all")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Running model Tabular1Unimodal
    Training: | | 0/? [00:00
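
The k-fold settings from section 2 (``kfold=True`` and ``num_folds=3``) are passed to ``prepare_fusion_data``, which is why the results report one score per fold. As a rough illustration of what k-fold cross-validation does, the sketch below partitions sample indices into folds in plain Python: each fold serves once as the validation set while the remaining folds are used for training. This is a generic sketch, not fusilli's own splitting code, which may shuffle or stratify differently; the helper ``kfold_indices``, its ``seed`` argument, and the sample count of 10 are hypothetical.

.. code-block:: Python

    import random


    def kfold_indices(n_samples, num_folds, seed=0):
        """Yield (train_indices, val_indices) pairs for k-fold cross-validation.

        Hypothetical helper for illustration only; fusilli performs its own
        splitting when kfold=True is passed to prepare_fusion_data.
        """
        if num_folds < 2:
            raise ValueError("k-fold cross-validation needs at least two folds")

        indices = list(range(n_samples))
        random.Random(seed).shuffle(indices)  # shuffle once so folds are random

        # Spread the remainder over the first folds so sizes differ by at most one
        fold_sizes = [
            n_samples // num_folds + (1 if i < n_samples % num_folds else 0)
            for i in range(num_folds)
        ]

        start = 0
        for size in fold_sizes:
            val = indices[start:start + size]
            train = indices[:start] + indices[start + size:]
            yield train, val
            start += size


    # Example: 10 samples split into 3 folds (validation sizes 4, 3 and 3)
    for fold, (train, val) in enumerate(kfold_indices(10, 3), start=1):
        print(f"fold{fold}: {len(train)} train / {len(val)} validation")

Every sample lands in exactly one validation fold, so the per-fold scores together cover the whole dataset once.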
.. csv-table::
   :header: "Method", "auroc", "accuracy", "fold1_auroc", "fold2_auroc", "fold3_auroc", "fold1_accuracy", "fold2_accuracy", "fold3_accuracy"

   "Tabular1 uni-modal", 0.926437, 0.694, 0.920627, 0.953280, 0.930301, 0.658683, 0.724551, 0.698795
   "Tabular2 uni-modal", 0.912041, 0.622, 0.896803, 0.920812, 0.932999, 0.616766, 0.598802, 0.650602
   "Concatenating tabular feature maps", 0.947997, 0.742, 0.934640, 0.969213, 0.964932, 0.664671, 0.796407, 0.765060
   "Concatenating tabular data", 0.890741, 0.556, 0.953003, 0.500000, 0.980494, 0.730539, 0.101796, 0.837349
   "Channel-wise multiplication net (tabular)", 0.697613, 0.270, 0.641849, 0.769003, 0.693229, 0.239521, 0.365269, 0.204819
   "Tabular Crossmodal multi-head attention", 0.945050, 0.678, 0.966650, 0.915441, 0.966022, 0.748503, 0.508982, 0.777108
   "Tabular decision", 0.971814, 0.800, 0.978862, 0.977966, 0.962375, 0.862275, 0.814371, 0.722892
   "MCVAE Tabular", 0.901287, 0.534, 0.926778, 0.911758, 0.899404, 0.586826, 0.532934, 0.481928
   "Edge Correlation GNN", 0.497640, 0.096, 0.504931, 0.534522, 0.474791, 0.113772, 0.101796, 0.072289
   "Attention-weighted GNN", 0.469901, 0.094, 0.500322, 0.538634, 0.462927, 0.071856, 0.095808, 0.114458
   "Activation function and tabular self-attention", 0.955697, 0.754, 0.954473, 0.955870, 0.970822, 0.742515, 0.724551, 0.795181
   "Activation function map fusion", 0.942065, 0.726, 0.927903, 0.961230, 0.947598, 0.652695, 0.778443, 0.746988
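
The per-fold columns show how stable each method is across folds. For example, "Concatenating tabular data" scores 0.953003 and 0.980494 AUROC on folds 1 and 3 but 0.500000 (chance level) on fold 2. A minimal sketch for ranking methods by their mean per-fold AUROC alongside its spread, using only the standard library, follows. The ``fold_auroc`` dictionary is a hypothetical layout holding a few rows copied by hand from the table, not the format fusilli saves its results in, and ``summarise`` is a hypothetical helper.

.. code-block:: Python

    from statistics import mean, stdev

    # Hypothetical layout: per-fold AUROC values copied from the results table
    fold_auroc = {
        "Tabular decision": [0.978862, 0.977966, 0.962375],
        "Concatenating tabular data": [0.953003, 0.500000, 0.980494],
        "Tabular1 uni-modal": [0.920627, 0.953280, 0.930301],
    }


    def summarise(scores_by_method):
        """Return (method, mean, sample std) tuples sorted by mean, best first."""
        rows = [(name, mean(s), stdev(s)) for name, s in scores_by_method.items()]
        return sorted(rows, key=lambda row: row[1], reverse=True)


    for name, avg, spread in summarise(fold_auroc):
        print(f"{name:<28} mean={avg:.3f} std={spread:.3f}")

A large standard deviation flags a method whose overall score hides one unstable fold, which the violin chart comparison is also designed to reveal.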


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (7 minutes 21.974 seconds)


.. _sphx_glr_download_auto_examples_model_comparison_plot_model_comparison_loop_kfold.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_model_comparison_loop_kfold.ipynb <plot_model_comparison_loop_kfold.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_model_comparison_loop_kfold.py <plot_model_comparison_loop_kfold.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_model_comparison_loop_kfold.zip <plot_model_comparison_loop_kfold.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_