.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/training_and_testing/plot_one_model_binary_kfold.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_training_and_testing_plot_one_model_binary_kfold.py:


K-Fold Cross-Validation: Binary Classification
======================================================

🚀 In this tutorial, we'll explore binary classification with K-fold cross-validation.
We'll show you how to train a fusion model on multimodal tabular data using K-fold cross-validation.

Specifically, we're using the :class:`~.TabularCrossmodalMultiheadAttention` model.

Data: we're using 500 rows of the MNIST dataset, split into top and bottom halves that serve as our two tabular modalities.
The bottom half's values have been inverted to make the task more difficult.
The prediction label is whether the digit is odd or even.

Key Features:

- 📥 Importing a model based on its path.
- 🧪 Training and testing a model with k-fold cross-validation.
- 📈 Plotting the loss curves of each fold.
- 📊 Visualising the results of a single K-Fold model using the :class:`~.ConfusionMatrix` class.

.. GENERATED FROM PYTHON SOURCE LINES 22-33

.. code-block:: Python


    import matplotlib.pyplot as plt
    from tqdm.auto import tqdm
    import os

    from fusilli.data import prepare_fusion_data
    from fusilli.eval import ConfusionMatrix
    from fusilli.train import train_and_save_models

    # sphinx_gallery_thumbnail_number = -1

.. GENERATED FROM PYTHON SOURCE LINES 34-38

1. Import the fusion model 🔍
--------------------------------
We're importing only one model for this example, the :class:`~.TabularCrossmodalMultiheadAttention` model.
Instead of using the :func:`~fusilli.utils.model_chooser.import_chosen_fusion_models` function, we import the model class directly, just as we would import any other library object.

.. GENERATED FROM PYTHON SOURCE LINES 38-44

.. code-block:: Python


    from fusilli.fusionmodels.tabularfusion.crossmodal_att import (
        TabularCrossmodalMultiheadAttention,
    )

.. GENERATED FROM PYTHON SOURCE LINES 45-63

2. Set the training parameters 🎯
-----------------------------------
Now we're configuring our training parameters.

For training and testing, the necessary parameters are:

- Paths to the input data files.
- Paths to the output directories.
- ``prediction_task``: the type of prediction to be performed. This is either ``regression``, ``binary``, or ``multiclass``.

Some optional parameters are:

- ``kfold``: a boolean of whether to use k-fold cross-validation (True) or not (False). By default, this is set to False.
- ``num_folds``: the number of folds to use. It cannot be set to 1.
- ``wandb_logging``: a boolean of whether to log the results using Weights and Biases (True) or not (False). Default is False.
- ``test_size``: the proportion of the dataset to include in the test split. Default is 0.2.
- ``batch_size``: the batch size to use for training. Default is 8.
- ``multiclass_dimensions``: the number of classes to use for multiclass classification. Default is None unless ``prediction_task`` is ``multiclass``.
- ``max_epochs``: the maximum number of epochs to train for. Default is 1000.

.. GENERATED FROM PYTHON SOURCE LINES 63-93

.. code-block:: Python


    # Binary task (predicting a binary variable - 0 or 1)
    prediction_task = "binary"

    # Set the batch size
    batch_size = 32

    # Enable k-fold cross-validation with k=5
    kfold = True
    num_folds = 5

    # Setting output directories
    output_paths = {
        "losses": "loss_logs/one_model_binary_kfold",
        "checkpoints": "checkpoints/one_model_binary_kfold",
        "figures": "figures/one_model_binary_kfold",
    }

    # Create the output directories if they don't exist
    for path in output_paths.values():
        os.makedirs(path, exist_ok=True)

    # Clearing the loss logs directory (only for the example notebooks)
    for subdir in os.listdir(output_paths["losses"]):
        subdir_path = os.path.join(output_paths["losses"], subdir)
        # remove the files inside each subdirectory
        for file in os.listdir(subdir_path):
            os.remove(os.path.join(subdir_path, file))
        # then remove the now-empty subdirectory
        os.rmdir(subdir_path)

.. GENERATED FROM PYTHON SOURCE LINES 94-97

3. Specifying input file paths 🔮
------------------------------------
We're using the MNIST dataset for this example, and the CSV files are stored in the ``_static/mnist_data`` directory with the documentation files.

.. GENERATED FROM PYTHON SOURCE LINES 97-104

.. code-block:: Python


    data_paths = {
        "tabular1": "../../_static/mnist_data/mnist1_binary.csv",
        "tabular2": "../../_static/mnist_data/mnist2_binary.csv",
        "image": "",
    }

.. GENERATED FROM PYTHON SOURCE LINES 105-119

4. Training the fusion model 🏁
--------------------------------------
Now we're ready to train our model with the :func:`~fusilli.train.train_and_save_models` function.

First we need to create a data module using the :func:`~fusilli.data.prepare_fusion_data` function.
This function takes the following parameters:

- ``prediction_task``: the type of prediction to be performed.
- ``fusion_model``: the fusion model to be trained.
- ``data_paths``: the paths to the input data files.
- ``output_paths``: the paths to the output directories.
- ``kfold``, ``num_folds`` and ``batch_size``: the optional parameters we set in step 2.
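
In standard k-fold cross-validation, setting ``kfold=True`` and ``num_folds=5`` means the rows are partitioned into five folds: each fold is held out once as the test set while a model is trained on the other four. The snippet below is a minimal, standard-library sketch of that partitioning with a hypothetical ``kfold_indices`` helper. It assumes a single reproducible shuffle and is not fusilli's own splitting code, which may shuffle or stratify the data differently.

.. code-block:: Python

    import random


    def kfold_indices(n_rows, num_folds, seed=42):
        """Hypothetical helper: yield (train_idx, test_idx) pairs for k-fold CV.

        A generic illustration only; fusilli's own fold splitting may differ
        (e.g. in shuffling or stratification).
        """
        if num_folds < 2:
            raise ValueError("num_folds must be at least 2")
        indices = list(range(n_rows))
        random.Random(seed).shuffle(indices)  # shuffle once, reproducibly

        # Spread the rows as evenly as possible: the first
        # (n_rows % num_folds) folds get one extra row.
        base, extra = divmod(n_rows, num_folds)
        start = 0
        for fold in range(num_folds):
            size = base + (1 if fold < extra else 0)
            test_idx = indices[start:start + size]
            train_idx = indices[:start] + indices[start + size:]
            yield train_idx, test_idx
            start += size


    # 500 rows and 5 folds, as in this tutorial
    for fold, (train_idx, test_idx) in enumerate(kfold_indices(500, 5)):
        # prints "fold 0: 400 train rows, 100 test rows", then the same for folds 1 to 4
        print(f"fold {fold}: {len(train_idx)} train rows, {len(test_idx)} test rows")

Because every row lands in exactly one test fold, each of the five trained models is evaluated on rows it never saw during training, and together the test folds cover the whole dataset.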

Then we pass the data module and the fusion model to the :func:`~fusilli.train.train_and_save_models` function.
We're not using checkpointing for this example, so we set ``enable_checkpointing=False``.
We're also setting ``show_loss_plot=True`` to plot the loss curves for each fold.

.. GENERATED FROM PYTHON SOURCE LINES 119-143

.. code-block:: Python


    fusion_model = TabularCrossmodalMultiheadAttention

    print("method_name:", fusion_model.method_name)
    print("modality_type:", fusion_model.modality_type)
    print("fusion_type:", fusion_model.fusion_type)

    dm = prepare_fusion_data(prediction_task=prediction_task,
                             fusion_model=fusion_model,
                             data_paths=data_paths,
                             output_paths=output_paths,
                             kfold=kfold,
                             num_folds=num_folds,
                             batch_size=batch_size)

    # train and test
    single_model_list = train_and_save_models(
        data_module=dm,
        fusion_model=fusion_model,
        enable_checkpointing=False,  # False for the example notebooks
        show_loss_plot=True,
    )



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_001.png
          :alt: Loss Curves for TabularCrossmodalMultiheadAttention_fold_0
          :srcset: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_001.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_002.png
          :alt: Loss Curves for TabularCrossmodalMultiheadAttention_fold_1
          :srcset: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_002.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_003.png
          :alt: Loss Curves for TabularCrossmodalMultiheadAttention_fold_2
          :srcset: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_003.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_004.png
          :alt: Loss Curves for TabularCrossmodalMultiheadAttention_fold_3
          :srcset: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_004.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_005.png
          :alt: Loss Curves for TabularCrossmodalMultiheadAttention_fold_4
          :srcset: /auto_examples/training_and_testing/images/sphx_glr_plot_one_model_binary_kfold_005.png
          :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    method_name: Tabular Crossmodal multi-head attention
    modality_type: tabular_tabular
    fusion_type: attention