fusilli.train

Contains the train_and_test function, which trains and tests a model or, for k-fold training, a single fold.

Functions

_store_trained_model(trained_model, ...)

Stores the trained model to a dictionary.

train_and_save_models(data_module, fusion_model)

Trains/tests the model and saves the trained model to a dictionary for further analysis.

train_and_test(data_module, k, fusion_model, ...)

Trains and tests a model or, for k-fold training, a single fold.

train_and_save_models(data_module, fusion_model, wandb_logging=False, extra_log_string_dict=None, layer_mods=None, max_epochs=1000, enable_checkpointing=True, show_loss_plot=False, project_name=None, metrics_list=None)

Trains/tests the model and saves the trained model to a dictionary for further analysis. If the training type is kfold, it will train and test the model for each fold and store the trained models in a list.

Parameters:
  • data_module (pytorch lightning data module) – Data module.

  • fusion_model (class) – Fusion model class.

  • wandb_logging (bool) – Whether to log to wandb. Default False.

  • extra_log_string_dict (dict) – Dictionary of extra strings to add to the run name during logging, e.g. the hyperparameters when you’re running the same model with different hyperparameters. Input format {“name”: “value”}. Each entry is added to the run name as “name_value”, and a tag “name_value” is also added. Default None.

  • layer_mods (dict) – Dictionary of layer modifications, used to modify the architecture of the model. Input format {“model”: {“layer_group”: “modification”}, …}, e.g. {“TabularCrossmodalAttention”: {“mod1_layers”: new mod 1 layers nn.ModuleDict}}. Default None.

  • max_epochs (int) – Maximum number of epochs. Default 1000.

  • enable_checkpointing (bool) – Whether to enable checkpointing. Default True.

  • show_loss_plot (bool) – Whether to show the loss plot. Default False.

  • project_name (str or None) – Name of the project to log to in Weights and Biases. Default None. If None, the project name will be “fusilli”.

  • metrics_list (list) – List of metrics to use for model evaluation. Default None. If None, the metrics are selected automatically based on the prediction task (AUROC and accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list is used in the comparison evaluation figures to rank the models’ performances. Length must be 2 or more.
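
The metrics_list rules above (task-based defaults, a minimum length of 2, and the first metric used for ranking) can be sketched as follows. This is only an illustration, not the library's implementation: resolve_metrics, DEFAULT_METRICS, and the lowercase metric names are hypothetical and may not match the identifiers fusilli actually accepts.

```python
# Hypothetical default metric names per prediction task; the real
# identifiers accepted by the library may differ.
DEFAULT_METRICS = {
    "binary": ["auroc", "accuracy"],
    "multiclass": ["auroc", "accuracy"],
    "regression": ["r2", "mae"],
}


def resolve_metrics(metrics_list, prediction_task):
    """Return the metrics to evaluate, falling back to task defaults."""
    if metrics_list is None:
        return list(DEFAULT_METRICS[prediction_task])
    if len(metrics_list) < 2:
        raise ValueError("metrics_list must contain at least 2 metrics")
    return list(metrics_list)


metrics = resolve_metrics(None, "regression")
ranking_metric = metrics[0]  # the first metric ranks models in comparison figures
print(metrics, ranking_metric)
```

Putting the metric you care about most first in the list makes it the one used for ranking.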

Returns:

trained_models_list – List of trained models. Length of list is 1 if training type is train/test split. Length of list is num_k if training type is kfold.

Return type:

list
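
The relationship between the training type and the length of the returned list can be illustrated with a minimal sketch. The functions train_one and train_all are hypothetical stand-ins that perform no real training; they only mirror the documented behaviour of one list entry per fold, or a single entry for a train/test split.

```python
def train_one(fold):
    """Stand-in for a single training run; returns a placeholder 'model'."""
    return {"fold": fold}


def train_all(kfold, num_k=1):
    """Collect one trained model per fold (or one for a train/test split)."""
    folds = range(num_k) if kfold else [None]
    return [train_one(fold) for fold in folds]


single = train_all(kfold=False)                   # train/test split
cross_validated = train_all(kfold=True, num_k=5)  # 5-fold training
print(len(single), len(cross_validated))
```

Code that consumes the returned list can therefore loop over it in both cases without checking the training type first.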

train_and_test(data_module, k, fusion_model, kfold, extra_log_string_dict=None, layer_mods=None, max_epochs=1000, enable_checkpointing=True, show_loss_plot=False, wandb_logging=False, project_name=None, training_modifications=None, metrics_list=None)

Trains and tests a model or, for k-fold training, a single fold.

Parameters:
  • data_module (pytorch lightning data module) – Data module. Contains the train and val dataloaders.

  • k (int) – Fold number.

  • fusion_model (class) – Fusion model class.

  • kfold (bool) – Whether to train a kfold model.

  • extra_log_string_dict (dict) –

    Dictionary of extra strings to add to the run name during logging. Default None.

    E.g. if you’re running the same model with different hyperparameters, you can add the hyperparameters. Input format {“name”: “value”}. Each entry is added to the run name as “name_value”, and a tag “name_value” is also added.

  • layer_mods (dict) – Dictionary of layer modifications, used to modify the architecture of the model. Input format {“model”: {“layer_group”: “modification”}, …}, e.g. {“TabularCrossmodalAttention”: {“mod1_layers”: new mod 1 layers nn.ModuleDict}}. Default None.

  • max_epochs (int) – Maximum number of epochs. Default 1000.

  • enable_checkpointing (bool) – Whether to enable checkpointing. Default True.

  • show_loss_plot (bool) – Whether to show the loss plot. Default False. If True, the loss plot will be shown after training with plt.show(). If False, the loss plot will be saved to the log directory.

  • wandb_logging (bool) – Whether to log to Weights and Biases. Default False.

  • project_name (str or None) – Name of the project to log to in Weights and Biases. Default None. If None, the project name will be “fusilli”.

  • training_modifications (dict) – Dictionary of training modifications, used to modify the training process. Possible keys include “accelerator” and “devices”. Default None.

  • metrics_list (list) – List of metrics to use for model evaluation. Default None. If None, the metrics are selected automatically based on the prediction task (AUROC and accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list is used in the comparison evaluation figures to rank the models’ performances. Length must be 2 or more.
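
The “name_value” convention described for extra_log_string_dict in the list above can be sketched as follows. This is a minimal illustration under stated assumptions: build_run_name is a hypothetical helper, and the way the pieces are joined to the base run name (with underscores here) is a guess rather than the library's documented behaviour.

```python
def build_run_name(base_name, extra_log_string_dict=None):
    """Hypothetical helper: append "name_value" pieces to a run name."""
    extra = extra_log_string_dict or {}
    # One "name_value" string per dictionary entry, in insertion order.
    tags = [f"{name}_{value}" for name, value in extra.items()]
    # Assumption: pieces are joined to the base name with underscores.
    run_name = "_".join([base_name, *tags])
    return run_name, tags


run_name, tags = build_run_name("TabularCrossmodalAttention", {"lr": "0.001"})
print(run_name)
print(tags)
```

Keeping values short (for example a learning rate or a layer count) keeps run names readable in the logging dashboard.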

Returns:

  • model (pytorch lightning model) – Trained model.

  • trainer (pytorch lightning trainer) – Trainer used to train the model.

  • metric_1 (float) – Metric 1 (depends on metric_name_list and prediction_task).

  • metric_2 (float) – Metric 2 (depends on metric_name_list and prediction_task).

  • val_reals (list) – List of validation real values.

  • val_preds (list) – List of validation predicted values.
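
A caller would typically unpack the six return values in the order listed above. The sketch below assumes they come back as a single tuple in that order, which the reference does not state explicitly. train_and_test_stub is a hypothetical stand-in that performs no training and returns placeholder values, so that the call shape can be shown without a real data module.

```python
def train_and_test_stub(data_module, k, fusion_model, kfold, **kwargs):
    """Stand-in with the documented return order; does no real training."""
    model, trainer = object(), object()
    metric_1, metric_2 = 0.0, 0.0  # placeholder values, not real results
    val_reals, val_preds = [], []
    return model, trainer, metric_1, metric_2, val_reals, val_preds


result = train_and_test_stub(
    data_module=None,   # a real call would pass a Lightning data module
    k=0,                # fold number
    fusion_model=None,  # a real call would pass a fusion model class
    kfold=True,
    training_modifications={"accelerator": "cpu", "devices": 1},
)
model, trainer, metric_1, metric_2, val_reals, val_preds = result
print(len(result))
```

The training_modifications values shown are examples of the two keys named in the parameter list; which values are valid depends on the underlying trainer.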