fusilli.utils.training_utils

Functions for initialising the PyTorch Lightning logger and trainer, getting final validation metrics from trained PyTorch Lightning models, and setting checkpoint filenames based on the model, its parameters, and user-defined strings.

Functions

get_checkpoint_filename_for_trained_fusion_model(...)

Gets the checkpoint filename for the trained fusion model using the model object.

get_checkpoint_filenames_for_subspace_models(...)

Get the checkpoint filenames for the subspace models, based on the subspace method class and the datamodule passed into it.

get_file_suffix_from_dict(extra_log_string_dict)

Get the extra name string and tags from the extra_log_string_dict.

get_final_val_metrics(trainer)

Get the final validation metrics from the trainer object.

init_trainer(logger, output_paths[, ...])

Initialise the PyTorch Lightning trainer object.

set_checkpoint_name(fusion_model[, fold, ...])

Set the checkpoint name for the current run of the main fusion model.

set_logger(fold, project_name, output_paths, ...)

Set the logger for the current run.

Classes

LitProgressBar([refresh_rate, ...])

Custom progress bar for the PyTorch Lightning trainer.

class LitProgressBar(refresh_rate: int = 1, process_position: int = 0, leave: bool = False)[source]

Bases: TQDMProgressBar

Custom progress bar for the PyTorch Lightning trainer, used to disable the progress bar during validation.

Parameters:

  • refresh_rate (int) – Number of batches between progress bar updates. Default 1.

  • process_position (int) – Position of the progress bar; useful when running multiple progress bars. Default 0.

  • leave (bool) – Whether to leave the finished progress bar on the terminal at the end of an epoch. Default False.

init_validation_tqdm()[source]

Overrides the validation tqdm bar, returning a disabled bar so that no progress bar is shown during validation.
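
The override typically looks something like the following minimal sketch; the class name is illustrative and fusilli's actual implementation may differ:

    from pytorch_lightning.callbacks import TQDMProgressBar
    from tqdm import tqdm

    class QuietValidationBar(TQDMProgressBar):
        """Illustrative progress bar that suppresses the validation bar."""

        def init_validation_tqdm(self) -> tqdm:
            # Returning a disabled tqdm means no progress bar is drawn
            # during validation.
            return tqdm(disable=True)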

get_checkpoint_filename_for_trained_fusion_model(checkpoint_dir, model, checkpoint_file_suffix, fold=None)[source]

Gets the checkpoint filename for the trained fusion model using the model object.

Checkpoints should follow the naming convention:

  • fusion_model_name_fold_k_{checkpoint_file_suffix} if fold is not None

  • fusion_model_name_{checkpoint_file_suffix} if fold is None

Parameters:
  • checkpoint_dir (str) – Path to the directory containing the checkpoints.

  • model (BaseModel) – BaseModel instance.

  • checkpoint_file_suffix (str) – Checkpoint file suffix.

  • fold (int) – Fold number. None if not using kfold. Default None.

Returns:

checkpoint_filename – Checkpoint filename.

Return type:

str
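
For illustration, the naming convention above can be reproduced with plain string formatting (the model name and suffix below are hypothetical):

    fusion_model_name = "ExampleFusionModel"  # hypothetical model name
    checkpoint_file_suffix = "_accuracy"      # hypothetical suffix
    fold = 1

    if fold is not None:
        filename = f"{fusion_model_name}_fold_{fold}{checkpoint_file_suffix}"
    else:
        filename = f"{fusion_model_name}{checkpoint_file_suffix}"

    print(filename)  # ExampleFusionModel_fold_1_accuracy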

get_checkpoint_filenames_for_subspace_models(subspace_method, k=None)[source]

Get the checkpoint filenames for the subspace models, based on the subspace method class and the datamodule passed into it.

Parameters:
  • subspace_method (class) – Subspace method class.

  • k (int) – Fold number. None if not using kfold. Default None.

Returns:

checkpoint_filenames – List of checkpoint filenames. One for each subspace model in the subspace method class.

Return type:

list

get_file_suffix_from_dict(extra_log_string_dict)[source]

Get the extra name string and tags from the extra_log_string_dict.

Parameters:

extra_log_string_dict (dict) – Extra strings to add to the run name, e.g. hyperparameter values if you are running the same model with different hyperparameters. Input format {"name": "value"}. Each entry is appended to the run name as "name_value", and "name_value" is also added as a tag.

Returns:

  • extra_name_string (str) – Extra name string to append to a path name.

  • extra_tags (list) – List of extra tags to add to the logged run (wandb).
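
The construction of the suffix and tags can be sketched as follows; this is an illustration of the documented behaviour, and the exact separator handling in fusilli may differ:

    extra_log_string_dict = {"lr": "0.001", "batchsize": "32"}  # hypothetical

    # Each entry becomes a "name_value" fragment; the same fragments are
    # used as wandb tags.
    extra_tags = [f"{name}_{value}" for name, value in extra_log_string_dict.items()]
    extra_name_string = "_" + "_".join(extra_tags)  # leading "_" is an assumption

    print(extra_name_string)  # _lr_0.001_batchsize_32
    print(extra_tags)         # ['lr_0.001', 'batchsize_32']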

get_final_val_metrics(trainer)[source]

Get the final validation metrics from the trainer object.

Parameters:

trainer (pl.Trainer) – PyTorch Lightning trainer object.

Returns:

  • metric1 (float) – Final validation metric 1.

  • metric2 (float) – Final validation metric 2.
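
A typical call looks like the sketch below, assuming trainer is a pl.Trainer returned by init_trainer and already fitted:

    # `trainer` is assumed to have completed trainer.fit(...) beforehand.
    metric1, metric2 = get_final_val_metrics(trainer)
    print(f"final validation metrics: {metric1:.4f}, {metric2:.4f}")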

init_trainer(logger, output_paths, max_epochs=1000, enable_checkpointing=True, checkpoint_filename=None, own_early_stopping_callback=None, training_modifications=None)[source]

Initialise the PyTorch Lightning trainer object.

Parameters:
  • logger (object) – PyTorch Lightning logger object.

  • output_paths (dict) – Dictionary of output paths for checkpoints, losses, and figures.

  • max_epochs (int) – Maximum number of epochs. Default 1000.

  • enable_checkpointing (bool) – Whether to enable checkpointing. If True, checkpoints are saved. False is used for the example notebooks in the repository/documentation. Default True.

  • checkpoint_filename (str) – Checkpoint filename. Default None, which uses the default checkpointing filename.

  • own_early_stopping_callback (object) – Custom early stopping callback object. Default None, which uses the default early stopping callback. To use your own early stopping callback, define it in the main training script and either pass it here or pass it into the datamodule object, from which it will be read.

  • training_modifications (dict) – Dictionary of modifications to the training process. Keys can include "accelerator", "devices", and "learning_rate". Default None.

Returns:

trainer – PyTorch Lightning trainer object.

Return type:

pl.Trainer
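
A sketch of a typical call; the output_paths keys follow the description above, while the filename and training modifications are illustrative:

    output_paths = {
        "checkpoints": "outputs/checkpoints",
        "losses": "outputs/losses",
        "figures": "outputs/figures",
    }

    trainer = init_trainer(
        logger=logger,  # e.g. the logger returned by set_logger (documented below)
        output_paths=output_paths,
        max_epochs=200,
        checkpoint_filename="ExampleFusionModel_lr_0.001",  # hypothetical filename
        training_modifications={"accelerator": "cpu", "devices": 1},
    )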

set_checkpoint_name(fusion_model, fold=None, extra_log_string_dict=None)[source]

Set the checkpoint name for the current run of the main fusion model.

Parameters:
  • fusion_model (class) – Fusion model class.

  • fold (int) – Fold number. None if not using kfold.

  • extra_log_string_dict (dict) – Extra strings to add to the run name, e.g. hyperparameter values if you are running the same model with different hyperparameters. Input format {"name": "value"}. Each entry is appended to the run name as "name_value", and "name_value" is also added as a tag. Default None.

Returns:

checkpoint_filename – Checkpoint filename.

Return type:

str
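
A sketch of a call for fold 3 of a k-fold run; MyFusionModel is a placeholder for one of fusilli's fusion model classes:

    checkpoint_filename = set_checkpoint_name(
        fusion_model=MyFusionModel,             # placeholder fusion model class
        fold=3,
        extra_log_string_dict={"lr": "0.001"},  # appended to the name as "lr_0.001"
    )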

set_logger(fold, project_name, output_paths, fusion_model, extra_log_string_dict=None, wandb_logging=False)[source]

Set the logger for the current run. If wandb_logging is True, the logger is set to WandbLogger; otherwise it is set to CSVLogger and the logs are saved to output_paths["losses"].

Parameters:
  • fold (int or None) – Fold number. None if not using kfold.

  • project_name (str or None) – Name of the project, used for wandb logging. If None, the project name is set to "fusilli".

  • output_paths (dict) – Dictionary of output paths for checkpoints, losses, and figures.

  • fusion_model (class) – Fusion model class.

  • extra_log_string_dict (dict) – Extra strings to add to the run name, e.g. hyperparameter values if you are running the same model with different hyperparameters. Input format {"name": "value"}. Each entry is appended to the run name as "name_value", and "name_value" is also added as a tag. Default None.

  • wandb_logging (bool) – Whether to use wandb logging. If True, the logger is set to WandbLogger; otherwise it is set to CSVLogger and the logs are saved to output_paths["losses"]. Default False.

Returns:

logger – WandbLogger object if wandb_logging is True, otherwise CSVLogger object.

Return type:

object
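
A sketch of setting up CSV logging for a non-k-fold run and passing the logger to the trainer; MyFusionModel and output_paths are placeholders carried over from the examples above:

    logger = set_logger(
        fold=None,                    # not using k-fold
        project_name=None,            # wandb project name; defaults to "fusilli"
        output_paths=output_paths,    # as in the init_trainer example above
        fusion_model=MyFusionModel,   # placeholder fusion model class
        extra_log_string_dict={"lr": "0.001"},
        wandb_logging=False,          # CSVLogger; logs saved to output_paths["losses"]
    )

    trainer = init_trainer(logger, output_paths=output_paths)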