# Source code for fusilli.utils.training_utils

"""
Functions for initialising the pytorch lightning logger and trainer, getting final validation metrics
from trained pytorch lightning models, and various functions for setting checkpoint filenames based
on model, parameters, and user-defined strings.
"""

import os
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from lightning.pytorch import Trainer

# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.loggers import CSVLogger, WandbLogger
from tqdm import tqdm


def get_file_suffix_from_dict(extra_log_string_dict):
    """
    Build a filename suffix and a list of run tags from ``extra_log_string_dict``.

    Parameters
    ----------
    extra_log_string_dict : dict or None
        Extra information to attach to the run name, e.g. hyperparameter values when
        the same model is run several times. Input format {"name": "value"}. Each
        entry contributes "_name_value" to the suffix and "name_value" to the tags.
        If None, no suffix and no tags are produced.

    Returns
    -------
    extra_name_string : str
        Concatenation of "_name_value" for every entry, in dictionary order; ""
        when ``extra_log_string_dict`` is None or empty.
    extra_tags : list
        One "name_value" string per entry, used as extra tags on the logged run
        (wandb).
    """
    if extra_log_string_dict is None:
        return "", []

    tags = [f"{key}_{str(value)}" for key, value in extra_log_string_dict.items()]
    # Every tag is also appended to the name string, each preceded by "_".
    name_suffix = "".join("_" + tag for tag in tags)
    return name_suffix, tags
def set_logger(
    fold,
    project_name,
    output_paths,
    fusion_model,
    extra_log_string_dict=None,
    wandb_logging=False,
):
    """
    Create the logger for the current run.

    With ``wandb_logging`` a WandbLogger is created (logging under ``./logs`` of the
    current working directory); otherwise a CSVLogger writes into
    ``output_paths["losses"]``.

    Parameters
    ----------
    fold : int or None
        Fold number. None if not using kfold.
    project_name : str or None
        Name of the wandb project. If None, the project name "fusilli" is used.
        Only used for wandb logging.
    output_paths : dict
        Dictionary of output paths for checkpoints, logs, and figures. The "losses"
        entry is the CSVLogger save directory.
    fusion_model : class or object
        Fusion model class (or instance, in which case its class name is used). Its
        ``modality_type`` and ``fusion_type`` attributes become run tags.
    extra_log_string_dict : dict
        Extra string to add to the run name, e.g. hyperparameter values when the
        same model is run several times. Input format {"name": "value"}. In the run
        name, the extra string will be added as "name_value", and a tag
        "name_value" will be added. With wandb, each entry is also stored in the
        run config. Default None.
    wandb_logging : bool
        Whether to use wandb logging. If True, the logger is a WandbLogger,
        otherwise a CSVLogger saving to output_paths["losses"]. Default False.

    Returns
    -------
    logger : object
        WandbLogger if ``wandb_logging`` is truthy, otherwise CSVLogger.
    """
    # Accept either a model class (has __name__) or an instance of one.
    method_name = (
        fusion_model.__name__
        if hasattr(fusion_model, "__name__")
        else fusion_model.__class__.__name__
    )
    base_tags = [fusion_model.modality_type, fusion_model.fusion_type]

    extra_name_string, extra_tags = get_file_suffix_from_dict(extra_log_string_dict)

    if fold is None:
        run_name = f"{method_name}{extra_name_string}"
        run_tags = base_tags + extra_tags
    else:
        run_name = f"{method_name}_fold_{fold}{extra_name_string}"
        run_tags = base_tags + [f"fold_{str(fold)}"] + extra_tags

    if not wandb_logging:
        return CSVLogger(
            save_dir=output_paths["losses"],
            name="",
            version=run_name,
        )

    logger = WandbLogger(
        save_dir=os.getcwd() + "/logs",
        project="fusilli" if project_name is None else project_name,
        name=run_name,
        tags=run_tags,
        log_model=True,
        group=method_name,
        reinit=True,
    )

    # Record the method name and any user-supplied extras in the wandb run config.
    run_config = logger.experiment.config
    run_config["method_name"] = method_name
    if extra_log_string_dict is not None:
        for config_key, config_value in extra_log_string_dict.items():
            run_config[config_key] = config_value

    return logger
def set_checkpoint_name(fusion_model, fold=None, extra_log_string_dict=None):
    """
    Set the checkpoint filename template for the current run of the main fusion model.

    The result has the form ``{model_name}[_fold_{fold}]{extra}_{epoch:02d}``, where
    ``{epoch:02d}`` is left as a literal placeholder for ``ModelCheckpoint`` to fill.

    Parameters
    ----------
    fusion_model : class or object
        Fusion model class. An instance is also accepted, in which case its class
        name is used (consistent with ``set_logger``).
    fold : int
        Fold number. None if not using kfold.
    extra_log_string_dict : dict
        Extra string to add to the run name, e.g. hyperparameter values when the
        same model is run several times. Input format {"name": "value"}. The extra
        string will be added as "_name_value". Default None.

    Returns
    -------
    checkpoint_filename : str
        Checkpoint filename template.
    """
    # Previously only classes were accepted (instances raised AttributeError);
    # derive the name the same way set_logger does.
    if hasattr(fusion_model, "__name__"):
        model_name = str(fusion_model.__name__)
    else:
        model_name = fusion_model.__class__.__name__

    extra_name_string, _ = get_file_suffix_from_dict(extra_log_string_dict)

    fold_part = "" if fold is None else "_fold_" + str(fold)

    # "{epoch:02d}" is deliberately a plain (non f-) string: the braces must reach
    # ModelCheckpoint unformatted.
    return model_name + fold_part + extra_name_string + "_{epoch:02d}"
def get_checkpoint_filenames_for_subspace_models(subspace_method, k=None):
    """
    Build the checkpoint filenames for the subspace models of a subspace method.

    Names follow ``subspace_{fusion_model_name}_{subspace_model_name}[_fold_{k}]{extra}``,
    where the fusion model and the extra suffix are read from the datamodule attached
    to ``subspace_method``.

    Parameters
    ----------
    subspace_method : class
        Subspace method object exposing ``datamodule`` (with ``fusion_model`` and
        ``extra_log_string_dict``) and ``subspace_models``.
    k : int
        Fold number. None if not using kfold. Default None.

    Returns
    -------
    checkpoint_filenames : list
        One checkpoint filename per entry of ``subspace_method.subspace_models``, in
        the same order.
    """
    datamodule = subspace_method.datamodule
    parent_model = datamodule.fusion_model

    # The fusion model may be given as a class or as an instance.
    if hasattr(parent_model, "__name__"):
        parent_name = parent_model.__name__
    else:
        parent_name = parent_model.__class__.__name__

    suffix, _ = get_file_suffix_from_dict(datamodule.extra_log_string_dict)
    fold_part = "" if k is None else "_fold_" + str(k)

    return [
        "_".join(["subspace", parent_name, sub_model.__name__]) + fold_part + suffix
        for sub_model in subspace_method.subspace_models
    ]
def get_checkpoint_filename_for_trained_fusion_model(
    checkpoint_dir, model, checkpoint_file_suffix, fold=None
):
    """
    Gets the checkpoint filename for the trained fusion model using the model object.

    A file in ``checkpoint_dir`` matches when its name starts with:

    * ``{model_class_name}_fold_{fold}{checkpoint_file_suffix}`` if fold is not None
    * ``{model_class_name}{checkpoint_file_suffix}`` if fold is None

    and the character right after that prefix is not a digit. The digit rule stops
    e.g. fold 1 ("..._fold_1") from matching fold 10 ("..._fold_10_epoch=03.ckpt"),
    or a suffix "_lr_0.1" from matching "_lr_0.15".

    Parameters
    ----------
    checkpoint_dir : str
        Path to the directory containing the checkpoints.
    model : BaseModel
        BaseModel model object instance; the class name of ``model.model`` is used.
    checkpoint_file_suffix : str or None
        Checkpoint file suffix, appended directly after the model/fold part. None is
        treated as "".
    fold : int
        Fold number. None if not using kfold. Default None.

    Returns
    -------
    checkpoint_filename : str
        Full path (``checkpoint_dir`` joined with the matching filename).

    Raises
    ------
    ValueError
        If no file, or more than one file, matches.
    """
    if checkpoint_file_suffix is None:
        checkpoint_file_suffix = ""

    model_name = model.model.__class__.__name__
    if fold is None:
        ckpt_path_beginning = model_name + checkpoint_file_suffix
    else:
        ckpt_path_beginning = (
            model_name + "_fold_" + str(fold) + checkpoint_file_suffix
        )

    prefix_len = len(ckpt_path_beginning)
    result = [
        filename
        for filename in os.listdir(checkpoint_dir)
        if filename.startswith(ckpt_path_beginning)
        # A digit straight after the prefix means the prefix ended mid-number, so
        # the file belongs to another fold/run. ("".isdigit() is False, so a name
        # equal to the prefix is still accepted.)
        and not filename[prefix_len:prefix_len + 1].isdigit()
    ]

    if len(result) == 0:
        raise ValueError(
            f"Could not find checkpoint file with name {ckpt_path_beginning} in {checkpoint_dir}."
        )
    if len(result) > 1:
        # if the model is a subspace method, then we need to check if the checkpoint file is for the subspace model
        # or the big fusion model
        # TODO add this check
        raise ValueError(
            f"Found multiple checkpoint files with name {ckpt_path_beginning} in {checkpoint_dir}."
        )

    return os.path.join(checkpoint_dir, result[0])
class LitProgressBar(TQDMProgressBar):
    """
    Custom TQDM progress bar for the pytorch lightning trainer with the validation
    bar disabled.

    Subclasses Lightning's ``TQDMProgressBar`` and only overrides the construction
    of the validation bar; all other bars are inherited unchanged.
    """

    def init_validation_tqdm(self):
        """Return a disabled tqdm bar so validation progress is not displayed."""
        return tqdm(disable=True)
def init_trainer(
    logger,
    output_paths,
    max_epochs=1000,
    enable_checkpointing=True,
    checkpoint_filename=None,
    own_early_stopping_callback=None,
    training_modifications=None,
):
    """
    Initialise the pytorch lightning trainer object.

    Parameters
    ----------
    logger : object or None
        Pytorch lightning logger object. If None, logging is disabled entirely
        (``logger=False`` is passed to the trainer).
    output_paths : dict
        Dictionary of output paths for checkpoints, losses, and figures. The
        "checkpoints" entry is the checkpoint directory used when
        ``checkpoint_filename`` is given and checkpointing is enabled.
    max_epochs : int
        Maximum number of epochs. Default 1000.
    enable_checkpointing : bool
        Whether to enable checkpointing. If True, then checkpoints will be saved.
        We use False for the example notebooks in the repository/documentation.
        When False, ``checkpoint_filename`` is ignored. Default True.
    checkpoint_filename : str
        Checkpoint filename template passed to ``ModelCheckpoint``. Default None to
        use Lightning's default checkpointing.
    own_early_stopping_callback : object
        Own early stopping callback object. Default None to use the default early
        stopping callback (monitor "val_loss", mode "min", patience 15). If you want
        to use your own early stopping callback, define it in the main training
        script and pass it here, or pass it into the datamodule object and it will
        be read from there.
    training_modifications : dict
        Dictionary of training modifications. This function reads "accelerator"
        (default "cpu") and "devices" (default 1); other keys (e.g.
        "learning_rate") are not used here.

    Returns
    -------
    trainer : pl.Trainer
        Pytorch lightning trainer object.
    """
    if own_early_stopping_callback is not None:
        early_stop_callback = own_early_stopping_callback
    else:
        early_stop_callback = EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=15,
            verbose=False,
            mode="min",
        )

    callbacks_list = [early_stop_callback, LitProgressBar()]

    # Lightning raises a MisconfigurationException when a ModelCheckpoint callback
    # is supplied together with enable_checkpointing=False, so only add the named
    # checkpoint callback when checkpointing is actually enabled.
    if enable_checkpointing and checkpoint_filename is not None:
        checkpoint_callback = ModelCheckpoint(
            filename=checkpoint_filename,
            dirpath=output_paths["checkpoints"],
            enable_version_counter=False,  # overwrites files with same name
        )
        callbacks_list.append(checkpoint_callback)

    # Accelerator and device count may be overridden by the user.
    accelerator = "cpu"
    devices = 1
    if training_modifications is not None:
        accelerator = training_modifications.get("accelerator", accelerator)
        devices = training_modifications.get("devices", devices)

    if logger is None:
        # No logger at all (not even a CSVLogger): the trainer needs logger=False
        # to disable logging.
        logger = False

    return Trainer(
        max_epochs=max_epochs,
        num_sanity_val_steps=0,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks_list,
        log_every_n_steps=2,
        logger=logger,
        enable_checkpointing=enable_checkpointing,
    )
def get_final_val_metrics(trainer):
    """
    Get the final validation metrics from the trainer object.

    For every name in ``trainer.model.metric_names_list`` the value logged under
    ``"{name}_val"`` in ``trainer.callback_metrics`` is read and converted with
    ``.item()``.

    Parameters
    ----------
    trainer : pl.Trainer
        Pytorch lightning trainer object.

    Returns
    -------
    metrics : list
        One value per metric name, in the order of ``metric_names_list``.

    Raises
    ------
    ValueError
        If ``trainer.callback_metrics`` is empty, or a metric's "_val" entry is
        missing.
    """
    metric_names = trainer.model.metric_names_list
    logged = trainer.callback_metrics

    if not logged:
        raise ValueError("trainer.callback_metrics is empty.")

    values = []
    for name in metric_names:
        logged_key = f"{name}_val"
        if logged_key not in logged:
            raise ValueError(
                f"{logged_key} not in trainer.callback_metrics.keys()."
            )
        values.append(logged[logged_key].item())
    return values