# Source code for fusilli.train

"""
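Training entry points: ``train_and_test`` trains and validates one model (or one fold of a k-fold run), and ``train_and_save_models`` runs it for every fold and returns the trained models in a list.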
"""

from fusilli.fusionmodels.base_model import BaseModel
from fusilli.utils.training_utils import (
    get_final_val_metrics,
    init_trainer,
    set_logger,
    set_checkpoint_name,
)
import wandb
from fusilli.utils import model_modifier
from lightning.pytorch.loggers import CSVLogger
from fusilli.utils.csv_loss_plotter import plot_loss_curve


def train_and_test(
    data_module,
    k,
    fusion_model,
    kfold,
    extra_log_string_dict=None,
    layer_mods=None,
    max_epochs=1000,
    enable_checkpointing=True,
    show_loss_plot=False,
    wandb_logging=False,
    project_name=None,
    training_modifications=None,
    metrics_list=None,
):
    """
    Trains and validates a model and, if k-fold trained, a single fold.

    Parameters
    ----------
    data_module : pytorch lightning data module or list
        Data module. Contains the train and val dataloaders. When ``kfold`` is
        True and ``fusion_model.fusion_type == "graph"``, this is indexed with
        ``k`` to get that fold's data module (i.e. a list of per-fold graph
        data modules is expected).
    k : int or None
        Fold number. Passed to the logger as ``fold`` and, when ``kfold`` is
        True, used to select the fold's data.
    fusion_model : class
        Fusion model class (not an instance); it is instantiated here.
    kfold : bool
        Whether to train a k-fold model.
    extra_log_string_dict : dict
        Dictionary of extra log strings. Extra string to add to the run name
        during logging, e.g. if you're running the same model with different
        hyperparameters, you can add the hyperparameters.
        Input format {"name": "value"}. In the run name, the extra string will
        be added as "name_value", and a tag will be added as "name_value".
    layer_mods : dict
        Dictionary of layer modifications. Used to modify the architecture of
        the model. Input format {"model": {"layer_group": "modification"}, ...},
        e.g. {"TabularCrossmodalAttention": {"mod1_layers": new mod 1 layers
        nn.ModuleDict}}. Default None.
    max_epochs : int
        Maximum number of epochs. Default 1000.
    enable_checkpointing : bool
        Whether to enable checkpointing. Default True.
    show_loss_plot : bool
        Whether to show the loss plot. Default False. Only used when the
        logger is a ``CSVLogger``. If True, the loss plot is shown after
        training with ``plt.show()``; if False, it is saved to the figures
        directory.
    wandb_logging : bool
        Whether to log to Weights and Biases. Default False.
    project_name : str or None
        Name of the project to log to in Weights and Biases. Default None.
        If None, the project name will be called "fusilli".
    training_modifications : dict
        Dictionary of training modifications. Used to modify the training
        process. Keys could be "accelerator", "devices".
    metrics_list : list
        List of metrics to use for model evaluation. Default None.
        If None, the metrics will be automatically selected based on the
        prediction task (AUROC, accuracy for binary/multiclass, R2 and MAE for
        regression). The first metric in the list will be used in the
        comparison evaluation figures to rank the models' performances.
        Length must be 2 or more.

    Returns
    -------
    pl_model : BaseModel
        Trained pytorch lightning model wrapping the fusion model. Its
        ``final_val_metrics`` attribute holds the result of
        ``get_final_val_metrics(trainer)`` taken after a final validation pass.
    """
    # Checkpoint filename is only computed when checkpointing is enabled.
    checkpoint_filename = None
    if enable_checkpointing:
        if kfold:
            checkpoint_filename = set_checkpoint_name(
                fusion_model,
                fold=k,
                extra_log_string_dict=extra_log_string_dict,
            )
        else:
            checkpoint_filename = set_checkpoint_name(
                fusion_model=fusion_model,
                extra_log_string_dict=extra_log_string_dict,
            )

    data_module, output_paths, train_dataloader, val_dataloader = _select_fold_data(
        data_module, k, fusion_model, kfold
    )

    logger = set_logger(
        fold=k,
        project_name=project_name,
        output_paths=output_paths,
        fusion_model=fusion_model,
        extra_log_string_dict=extra_log_string_dict,
        wandb_logging=wandb_logging,
    )

    trainer = init_trainer(
        logger,
        output_paths=output_paths,
        max_epochs=max_epochs,
        enable_checkpointing=enable_checkpointing,
        checkpoint_filename=checkpoint_filename,
        own_early_stopping_callback=data_module.own_early_stopping_callback,
        training_modifications=training_modifications,
    )

    # Wrap the fusion model in the pytorch lightning framework, hence pl_model.
    pl_model = BaseModel(
        fusion_model(
            prediction_task=data_module.prediction_task,
            data_dims=data_module.data_dims,  # data_dims is a list of tuples
            multiclass_dimensions=data_module.multiclass_dimensions,
        ),
        metrics_list=metrics_list,
    )

    if layer_mods is not None:
        pl_model.model = model_modifier.modify_model_architecture(
            pl_model.model, layer_mods
        )

    # Graph methods use masks to select train and val nodes rather than
    # separate train and val dataloaders.
    if pl_model.model.fusion_type == "graph":
        pl_model.train_mask = data_module.input_train_nodes
        pl_model.val_mask = data_module.input_test_nodes

    trainer.fit(
        pl_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

    # Run one more validation pass on the trained model; presumably
    # get_final_val_metrics reads the metrics logged by it (TODO confirm).
    trainer.validate(pl_model, val_dataloader)

    pl_model.final_val_metrics = get_final_val_metrics(trainer)

    if isinstance(logger, CSVLogger):
        plot_loss_curve(
            figures_path=output_paths["figures"], logger=logger, show=show_loss_plot
        )

    return pl_model


def _select_fold_data(data_module, k, fusion_model, kfold):
    """
    Pick the data module, output paths and dataloaders for this run.

    Returns a tuple ``(data_module, output_paths, train_dataloader,
    val_dataloader)``. For k-fold graph models ``data_module`` is indexed by
    ``k``; for other k-fold models the dataloaders are requested with
    ``fold_idx=k``; otherwise the plain dataloaders are used.
    """
    if kfold:
        if fusion_model.fusion_type == "graph":
            # Graph k-fold: one data module per fold.
            data_module = data_module[k]
            output_paths = data_module.output_paths
            train_dataloader = data_module.train_dataloader()
            val_dataloader = data_module.val_dataloader()
        else:
            output_paths = data_module.output_paths
            train_dataloader = data_module.train_dataloader(fold_idx=k)
            val_dataloader = data_module.val_dataloader(fold_idx=k)
    else:
        train_dataloader = data_module.train_dataloader()
        val_dataloader = data_module.val_dataloader()
        output_paths = data_module.output_paths
    return data_module, output_paths, train_dataloader, val_dataloader
def _store_trained_model(trained_model, trained_models_dict): """ Stores the trained model to a dictionary. If model type is already in dictionary (e.g. if it's a kfold model), append the model to the list of models. Parameters ---------- trained_model : pytorch lightning model Trained model. trained_models_dict : dict Dictionary of trained models. Returns ------- trained_models_dict : dict Dictionary of trained models. """ # get model name classname = trained_model.model.__class__.__name__ # if model is already in dictionary, we're training a kfold model if classname in trained_models_dict: # if the model is already a list, append the new model to the list # this is for when we're training a kfold model and on the third fold onwards if isinstance(trained_models_dict[classname], list): trained_models_dict[classname].append(trained_model) # if the model is not a list, make it a list and append the new model to the list # this is for when we're training a kfold model and on the second fold else: trained_models_dict[classname] = [ trained_models_dict[classname], trained_model, ] else: # If the model is not in the dictionary, add it as a new key-value pair # This is for when we're training a single model with train/test split or # when we're training a kfold model and on the first fold trained_models_dict[classname] = [trained_model] return trained_models_dict
def train_and_save_models(
    data_module,
    fusion_model,
    wandb_logging=False,
    extra_log_string_dict=None,
    layer_mods=None,
    max_epochs=1000,
    enable_checkpointing=True,
    show_loss_plot=False,
    project_name=None,
    metrics_list=None,
    training_modifications=None,
):
    """
    Trains/tests the model and collects the trained model(s) for further
    analysis.

    If the training type is k-fold, the model is trained and tested once per
    fold and the trained models are stored in a list. K-fold training is
    detected from a non-None ``num_folds`` attribute on ``data_module``, or on
    its first element when ``data_module`` is a list.

    Parameters
    ----------
    data_module : pytorch lightning data module or list
        Data module, or a list of per-fold data modules (k-fold only).
    fusion_model : class
        Fusion model class.
    wandb_logging : bool
        Whether to log to wandb. Default False. When True, ``wandb.finish()``
        is called after each fold (or after the single run).
    extra_log_string_dict : dict
        Dictionary of extra log strings. Extra string to add to the run name
        during logging, e.g. if you're running the same model with different
        hyperparameters, you can add the hyperparameters.
        Input format {"name": "value"}. In the run name, the extra string will
        be added as "name_value", and a tag will be added as "name_value".
    layer_mods : dict
        Dictionary of layer modifications. Used to modify the architecture of
        the model. Input format {"model": {"layer_group": "modification"}, ...},
        e.g. {"TabularCrossmodalAttention": {"mod1_layers": new mod 1 layers
        nn.ModuleDict}}. Default None.
    max_epochs : int
        Maximum number of epochs. Default 1000.
    enable_checkpointing : bool
        Whether to enable checkpointing. Default True.
    show_loss_plot : bool
        Whether to show the loss plot. Default False.
    project_name : str or None
        Name of the project to log to in Weights and Biases. Default None.
        If None, the project name will be called "fusilli".
    metrics_list : list
        List of metrics to use for model evaluation. Default None.
        If None, the metrics will be automatically selected based on the
        prediction task (AUROC, accuracy for binary/multiclass, R2 and MAE for
        regression). The first metric in the list will be used in the
        comparison evaluation figures to rank the models' performances.
        Length must be 2 or more.
    training_modifications : dict
        Dictionary of training modifications. Used to modify the training
        process. Keys could be "accelerator", "devices".

    Returns
    -------
    trained_models_list : list
        List of trained models. Length 1 for a train/test split,
        ``num_folds`` for k-fold training.

    Raises
    ------
    ValueError
        If ``data_module`` is a list that is empty or whose first element has
        no non-None ``num_folds`` (lists are only valid for k-fold training).
    """
    kfold, num_folds = _resolve_kfold(data_module)

    # A train/test split is a single "fold" with k=None.
    folds = range(num_folds) if kfold else [None]

    trained_models_list = []
    for k in folds:
        trained_model = train_and_test(
            data_module=data_module,
            k=k,
            fusion_model=fusion_model,
            kfold=kfold,
            extra_log_string_dict=extra_log_string_dict,
            layer_mods=layer_mods,
            max_epochs=max_epochs,
            enable_checkpointing=enable_checkpointing,
            show_loss_plot=show_loss_plot,
            wandb_logging=wandb_logging,
            project_name=project_name,
            metrics_list=metrics_list,
            training_modifications=training_modifications,
        )
        trained_models_list.append(trained_model)

        # Close the wandb run so each fold is logged as its own run.
        if wandb_logging:
            wandb.finish()

    return trained_models_list


def _resolve_kfold(data_module):
    """
    Decide whether ``data_module`` describes a k-fold run.

    Returns ``(kfold, num_folds)``. A single data module is k-fold when its
    ``num_folds`` attribute exists and is not None. A list of data modules is
    only meaningful for k-fold training, so its first element must carry a
    non-None ``num_folds``; otherwise ValueError is raised rather than failing
    later with an unbound local or an AttributeError on the list.
    """
    if isinstance(data_module, list):
        num_folds = (
            getattr(data_module[0], "num_folds", None) if data_module else None
        )
        if num_folds is None:
            raise ValueError(
                "A list of data modules is only supported for k-fold training "
                "(its first element must have a non-None 'num_folds')."
            )
        return True, num_folds

    num_folds = getattr(data_module, "num_folds", None)
    return num_folds is not None, num_folds