Source code for fusilli.data

"""
Data loading classes for multimodal and unimodal data.
This file contains functions and classes for loading the data for the different modalities, and
in training the subspace methods on the data (if the subspace methods need pre-training).
Train/test splits and k-fold cross validation are also implemented here.
"""

# imports
import pandas as pd
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data.lightning import LightningNodeData
from fusilli.utils import model_modifier


def downsample_img_batch(imgs, output_size):
    """
    Downsamples a batch of images to a specified size.

    Parameters
    ----------
    imgs : array-like
        Batch of images. Shape (batch_size, channels, height, width) or
        (batch_size, channels, height, width, depth) for 3D images.
    output_size : tuple
        Size to downsample the images to (height, width) or (height, width, depth)
        for 3D images. Do not put the batch_size or channel dimensions in the tuple.
        If None, no downsampling is performed.

    Returns
    -------
    downsampled_img : array-like
        Downsampled images (nearest-neighbour interpolation via F.interpolate).

    Raises
    ------
    ValueError
        If output_size does not have exactly one entry per spatial dimension of
        ``imgs``, has negative values, has more than 3 entries, or is larger than
        the images in any dimension.
    """
    if output_size is None:  # no downsampling requested
        return imgs

    n_requested = len(output_size)
    # spatial dims = every dim except batch and channel
    n_spatial = imgs.dim() - 2

    # too few dims, e.g. output_size (64,) for a (16, 3, 128, 128) batch
    if n_requested < n_spatial:
        raise ValueError(
            f"output_size must have {n_spatial} dimensions, not {n_requested}. "
            "Make sure to exclude the batch and channel dimensions so output_size looks like "
            "(height, width) for 2D or (height, width, depth) for 3D."
        )

    if any(i < 0 for i in output_size):
        raise ValueError(
            f"output_size must not have negative values, but got {output_size}."
        )

    if n_requested > 3:
        raise ValueError(
            f"output_size must have 2 or 3 dimensions, not {n_requested}."
        )

    if n_requested > 2 and imgs.dim() == 4:
        raise ValueError(
            f"output_size must have 2 dimensions, not {n_requested} because img_dims indicates a 2D image."
        )

    # catch-all for any remaining mismatch between output_size and the image rank
    if n_requested != n_spatial:
        raise ValueError(
            f"output_size must have {n_spatial} dimensions, not {n_requested}."
        )

    # only downsampling is supported: no requested dim may exceed the image dim
    if any(i > j for i, j in zip(output_size, imgs.shape[2:])):
        raise ValueError(
            f"output_size must be smaller than image dimensions, but got {output_size} and "
            f"image dimensions {imgs.shape[2:]}"
        )

    downsampled_img = F.interpolate(imgs, size=output_size, mode="nearest")
    return downsampled_img
class CustomDataset(Dataset):
    """
    Custom dataset class for multimodal data.

    Attributes
    ----------
    multimodal_flag : bool
        Flag for multimodal data. True if multimodal, False if unimodal.
    dataset1 : tensor
        Tensor of predictive features for modality 1 (multimodal only).
    dataset2 : tensor
        Tensor of predictive features for modality 2 (multimodal only).
    dataset : tensor
        Tensor of predictive features for uni-modal data.
    labels : tensor
        Tensor of labels: int64 for integer labels, float32 otherwise.
    """

    def __init__(self, pred_features, labels):
        """
        Parameters
        ----------
        pred_features : list, tuple or tensor
            A list/tuple of exactly two tensors (multimodal) or a single tensor
            (unimodal) of predictive features (i.e. tabular or image data without labels).
            Features are cast to float.
        labels : dataframe
            Dataframe of labels (column name must be "prediction_label").

        Raises
        ------
        ValueError
            If pred_features is not a list/tuple or tensor, or if a list/tuple does
            not contain exactly two modalities.
        """
        # a list/tuple means multimodal data; only 2 modalities are supported currently
        if isinstance(pred_features, (list, tuple)):
            if len(pred_features) != 2:
                raise ValueError(
                    f"Multimodal pred_features must contain exactly 2 tensors, not {len(pred_features)}"
                )
            self.multimodal_flag = True
            self.dataset1 = pred_features[0].float()
            self.dataset2 = pred_features[1].float()
        # a single tensor means unimodal data
        elif isinstance(pred_features, torch.Tensor):
            self.multimodal_flag = False
            self.dataset = pred_features.float()
        else:
            raise ValueError(
                f"pred_features must be a list or a tensor, not {type(pred_features)}"
            )

        # flatten the (n, 1) label column to shape (n,)
        label_values = labels[["prediction_label"]].to_numpy().reshape(-1)
        self.labels = torch.tensor(label_values)
        # any integer dtype (int32, int64, uint8, ...) is treated as class indices;
        # everything else (e.g. regression targets) becomes float
        if pd.api.types.is_integer_dtype(label_values.dtype):
            self.labels = self.labels.long()
        else:
            self.labels = self.labels.float()

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns
        -------
        int
            Number of samples (length of the first modality's features).
        """
        if self.multimodal_flag:
            return len(self.dataset1)
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Returns the item at the specified index.

        Parameters
        ----------
        idx : int
            Index of the item to return.

        Returns
        -------
        tuple
            (features1, features2, label) for multimodal data or
            (features, label) for unimodal data.
        """
        if self.multimodal_flag:
            return self.dataset1[idx], self.dataset2[idx], self.labels[idx]
        return self.dataset[idx], self.labels[idx]
class LoadDatasets:
    """
    Class for loading the different datasets for the different modalities.

    Attributes
    ----------
    tabular1_source : str
        Source csv file for tabular1 data. Also provides the labels for image data.
    tabular2_source : str
        Source csv file for tabular2 data ("" if there is no tabular2 data).
    img_source : str
        Source torch file for image data.
    image_downsample_size : tuple
        Size to downsample the images to (height, width, depth) or (height, width)
        for 2D images. None if not downsampling. (default None)
    """

    def __init__(self, sources, img_downsample_dims=None):
        """
        Parameters
        ----------
        sources : list
            List of source files. [tabular1_source, tabular2_source, img_source]
        img_downsample_dims : tuple
            Size to downsample the images to (height, width, depth) or (height, width)
            for 2D images. None if not downsampling. (default None)

        Raises
        ------
        ValueError
            If sources does not contain exactly three entries.
        ValueError
            If a CSV does not have an index column named "ID" or a label column
            named "prediction_label".
        """
        self.tabular1_source, self.tabular2_source, self.img_source = sources
        self.image_downsample_size = img_downsample_dims  # can choose own image size here

        # validate the CSVs up front so column errors surface at construction time
        self._check_columns(pd.read_csv(self.tabular1_source))
        if self.tabular2_source != "":
            self._check_columns(pd.read_csv(self.tabular2_source))

    @staticmethod
    def _check_columns(df):
        """Raise ValueError unless df has the required "ID" and "prediction_label" columns."""
        if "ID" not in df.columns:
            raise ValueError("The CSV must have an index column named 'ID'.")
        if "prediction_label" not in df.columns:
            raise ValueError(
                "The CSV must have a label column named 'prediction_label'."
            )

    @staticmethod
    def _read_tabular(source):
        """Read a CSV file and index it by its "ID" column."""
        return pd.read_csv(source).set_index("ID")

    @staticmethod
    def _features(df):
        """Return every non-label column of df as a float tensor."""
        return torch.Tensor(df.drop(columns=["prediction_label"]).values)

    @staticmethod
    def _align_rows(reference_df, other_df):
        """
        Reorder other_df's rows so its "ID" index matches reference_df's row order.

        Raises
        ------
        ValueError
            If the two frames do not contain the same set of unique IDs.
        """
        if other_df.index.equals(reference_df.index):
            return other_df  # already aligned: nothing to do
        if (
            not reference_df.index.is_unique
            or not other_df.index.is_unique
            or set(reference_df.index) != set(other_df.index)
        ):
            raise ValueError(
                "The 'ID' columns of the tabular1 and tabular2 CSVs must contain the same unique IDs."
            )
        return other_df.loc[reference_df.index]

    def _load_images(self, n_expected):
        """
        Load the image tensor, check it has one image per label row, and downsample it.

        Raises
        ------
        ValueError
            If the number of images differs from n_expected.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted image files.
        imgs = torch.load(self.img_source)
        if imgs.shape[0] != n_expected:
            raise ValueError(
                f"The image file contains {imgs.shape[0]} images but the tabular1 CSV "
                f"has {n_expected} rows; they must match one-to-one."
            )
        return downsample_img_batch(imgs, self.image_downsample_size)

    def load_tabular1(self):
        """
        Loads the tabular1-only dataset

        Returns
        ------
        dataset (tensor): tensor of predictive features
        data_dims (list): list of data dimensions [mod1_dim, mod2_dim, img_dim]
            i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
            i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
            32 features), and no image
        """
        tab_df = self._read_tabular(self.tabular1_source)
        pred_features = self._features(tab_df)
        dataset = CustomDataset(pred_features, tab_df[["prediction_label"]])
        return dataset, [pred_features.shape[1], None, None]

    def load_tabular2(self):
        """
        Loads the tabular2-only dataset

        Returns
        ------
        dataset (tensor): tensor of predictive features
        data_dims (list): list of data dimensions [mod1_dim, mod2_dim, img_dim]
            i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
            i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
            32 features), and no image
        """
        tab_df = self._read_tabular(self.tabular2_source)
        pred_features = self._features(tab_df)
        dataset = CustomDataset(pred_features, tab_df[["prediction_label"]])
        return dataset, [None, pred_features.shape[1], None]

    def load_img(self):
        """
        Loads the image-only dataset. Labels are taken from the tabular1 CSV, whose
        rows are assumed to be in the same order as the images.

        Returns
        ------
        dataset (tensor): tensor of predictive features
        data_dims (list): list of data dimensions [mod1_dim, mod2_dim, img_dim]
            i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
            i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
            32 features), and no image

        Raises
        ------
        ValueError
            If the number of images differs from the number of tabular1 rows.
        """
        label_df = self._read_tabular(self.tabular1_source)
        prediction_label = label_df[["prediction_label"]]
        all_scans_ds = self._load_images(len(prediction_label))
        dataset = CustomDataset(all_scans_ds, prediction_label)
        img_dim = list(all_scans_ds.shape[2:])  # not including batch size or channels
        return dataset, [None, None, img_dim]

    def load_tabular_tabular(self):
        """
        Loads the tabular1 and tabular2 multimodal dataset. tabular2 rows are matched
        to tabular1 rows by "ID"; labels come from tabular1.

        Returns
        ------
        dataset (tensor): tensor of predictive features
        data_dims (list): list of data dimensions [mod1_dim, mod2_dim, img_dim]
            i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
            i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
            32 features), and no image

        Raises
        ------
        ValueError
            If the two CSVs do not contain the same set of unique IDs.
        """
        tab1_df = self._read_tabular(self.tabular1_source)
        # match rows by ID rather than by file position
        tab2_df = self._align_rows(tab1_df, self._read_tabular(self.tabular2_source))
        tab1_pred_features = self._features(tab1_df)
        tab2_pred_features = self._features(tab2_df)
        dataset = CustomDataset(
            [tab1_pred_features, tab2_pred_features], tab1_df[["prediction_label"]]
        )
        return dataset, [tab1_pred_features.shape[1], tab2_pred_features.shape[1], None]

    def load_tab_and_img(self):
        """
        Loads the tabular1 and image multimodal dataset. Images are assumed to be in
        the same order as the tabular1 rows.

        Returns
        ------
        dataset (tensor): tensor of predictive features
        data_dims (list): list of data dimensions [mod1_dim, mod2_dim, img_dim]
            i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
            i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
            32 features), and no image

        Raises
        ------
        ValueError
            If the number of images differs from the number of tabular1 rows.
        """
        tab1_df = self._read_tabular(self.tabular1_source)
        tab1_features = self._features(tab1_df)
        label_df = tab1_df[["prediction_label"]]
        imgs = self._load_images(len(label_df))
        dataset = CustomDataset([tab1_features, imgs], label_df)
        img_dim = list(imgs.shape[2:])  # not including batch size or channels
        return dataset, [tab1_features.shape[1], None, img_dim]
class TrainTestDataModule(pl.LightningDataModule):
    """
    Custom pytorch lightning datamodule class for the different modalities.

    Attributes
    ----------
    sources : list
        List of source csv files. [Tabular1, Tabular2, Image]
    output_paths : dict
        Dictionary of output paths for saving the checkpoints, figures, and the losses.
    extra_log_string_dict : dict
        Dictionary of extra strings to add to the log.
    modality_methods : dict
        Dictionary of methods for loading the different modalities, keyed by modality type.
    fusion_model : class
        fusion model class. e.g. TabularCrossmodalAttention.
    modality_type : str
        The fusion model's modality type; selects the loader in modality_methods.
    batch_size : int
        Batch size.
    test_size : float
        Fraction of data to use for testing.
    prediction_task : str
        Prediction type (binary, multiclass, or regression).
    multiclass_dimensions : int
        Number of classes for multiclass prediction (None unless prediction_task is "multiclass").
    subspace_method : class
        Subspace method class (default None) (only for subspace methods).
    layer_mods : dict
        Dictionary of layer modifications to make to the subspace method. (default None)
    max_epochs : int
        Maximum number of epochs to train subspace methods for. (default 1000)
    dataset : tensor
        Tensor of predictive features. Created in prepare_data().
    data_dims : list
        List of data dimensions [mod1_dim, mod2_dim, img_dim]. Created in prepare_data().
    train_dataset : tensor
        Tensor of predictive features for training. Created in setup().
    test_dataset : tensor
        Tensor of predictive features for testing. Created in setup().
    subspace_method_train : class
        Subspace method class trained (only for subspace methods). Created in setup().
    own_early_stopping_callback : pytorch_lightning.callbacks.EarlyStopping
        Early stopping callback class.
    num_workers : int
        Number of workers for the dataloader (default 0).
    test_indices : list
        List of indices to use for testing (default None). If None, the test indices
        are randomly selected using the test_size parameter.
    kwargs : dict
        Dictionary of extra arguments for the subspace method class.
    """

    def __init__(
        self,
        fusion_model,
        sources,
        output_paths,
        prediction_task,
        batch_size,
        test_size,
        multiclass_dimensions,
        subspace_method=None,
        image_downsample_size=None,
        layer_mods=None,
        max_epochs=1000,
        extra_log_string_dict=None,
        own_early_stopping_callback=None,
        num_workers=0,
        test_indices=None,
        kwargs=None,
    ):
        """
        Parameters
        ----------
        fusion_model : class
            Fusion model class. e.g. "TabularCrossmodalAttention".
        sources : list
            List of source csv files.
        output_paths : dict
            Dictionary of output paths for saving the checkpoints, figures, and the losses.
        prediction_task : str
            Prediction task (binary, multiclass, regression).
        batch_size : int
            Batch size.
        test_size : float
            Fraction of data to use for testing.
        multiclass_dimensions : int
            Number of classes for multiclass prediction (ignored unless prediction_task
            is "multiclass").
        subspace_method : class
            Subspace method class (default None) (only for subspace methods).
        image_downsample_size : tuple
            Size to downsample the images to (height, width, depth) or (height, width)
            for 2D images. None if not downsampling. (default None)
        layer_mods : dict
            Dictionary of layer modifications to make to the subspace method. (default None)
        max_epochs : int
            Maximum number of epochs to train subspace methods for. (default 1000)
        extra_log_string_dict : dict
            Dictionary of extra strings to add to the log.
        own_early_stopping_callback : pytorch_lightning.callbacks.EarlyStopping
            Early stopping callback class (default None).
        num_workers : int
            Number of workers for the dataloader (default 0).
        test_indices : list
            List of indices to use for testing (default None). If None, the test indices
            are randomly selected using the test_size parameter.
        kwargs : dict
            Dictionary of extra arguments for the subspace method class.
        """
        super().__init__()
        self.sources = sources
        self.output_paths = output_paths
        self.extra_log_string_dict = extra_log_string_dict

        # one loader is enough: LoadDatasets reads (and validates) the CSVs in its
        # __init__, so creating one per modality re-read every file five times
        loader = LoadDatasets(self.sources, image_downsample_size)
        self.modality_methods = {
            "tabular1": loader.load_tabular1,
            "tabular2": loader.load_tabular2,
            "img": loader.load_img,
            "tabular_tabular": loader.load_tabular_tabular,
            "tabular_image": loader.load_tab_and_img,
        }

        self.fusion_model = fusion_model
        self.modality_type = self.fusion_model.modality_type
        self.batch_size = batch_size
        self.test_size = test_size
        self.prediction_task = prediction_task
        self.multiclass_dimensions = (
            multiclass_dimensions if self.prediction_task == "multiclass" else None
        )
        self.subspace_method = subspace_method
        self.layer_mods = layer_mods
        self.max_epochs = max_epochs
        self.own_early_stopping_callback = own_early_stopping_callback
        self.num_workers = num_workers
        self.test_indices = test_indices
        self.kwargs = kwargs

    def prepare_data(self):
        """
        Loads the data with the LoadDatasets loader for this model's modality type,
        setting self.dataset and self.data_dims.

        data_dims is [mod1_dim, mod2_dim, img_dim],
        i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
        i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
        32 features), and no image
        """
        self.dataset, self.data_dims = self.modality_methods[self.modality_type]()

    def _init_subspace_method(self, train_subspace):
        """Instantiate the subspace method and apply any requested layer modifications."""
        subspace_method = self.subspace_method(
            self, max_epochs=self.max_epochs, k=None, train_subspace=train_subspace
        )
        if self.layer_mods is not None:
            subspace_method = model_modifier.modify_model_architecture(
                subspace_method,
                self.layer_mods,
            )
        return subspace_method

    def setup(
        self,
        checkpoint_path=None,
    ):
        """
        Splits the data into train and test sets, and runs the subspace method if specified.

        If a subspace method is set, both splits are replaced by CustomDatasets of the
        latent features and data_dims is updated. If checkpoint_path is given, the
        subspace method is loaded from the checkpoint instead of being trained.

        Parameters
        ----------
        checkpoint_path : str
            Path to the checkpoint file for the subspace method (default None).

        Returns
        ------
        None
        """
        # split the dataset into train and test sets
        if self.test_indices is None:
            self.train_dataset, self.test_dataset = torch.utils.data.random_split(
                self.dataset, [1 - self.test_size, self.test_size]
            )
        else:
            held_out = set(self.test_indices)
            self.test_dataset = torch.utils.data.Subset(self.dataset, self.test_indices)
            self.train_dataset = torch.utils.data.Subset(
                self.dataset,
                [i for i in range(len(self.dataset)) if i not in held_out],
            )

        if self.subspace_method is None:
            return

        if checkpoint_path is None:
            # train the subspace method; this also returns the train latents
            self.subspace_method_train = self._init_subspace_method(train_subspace=True)
            train_latents, train_labels = self.subspace_method_train.train(
                self.train_dataset, self.test_dataset
            )
        else:
            # already trained: load from the checkpoint, then convert the train set
            self.subspace_method_train = self._init_subspace_method(train_subspace=False)
            self.subspace_method_train.load_ckpt(checkpoint_path)
            train_latents, train_labels, _ = self.subspace_method_train.convert_to_latent(
                self.train_dataset
            )

        test_latents, test_labels, data_dims = self.subspace_method_train.convert_to_latent(
            self.test_dataset
        )

        # replace the splits with latent-space datasets and record the new dimensions
        self.train_dataset = CustomDataset(train_latents, train_labels)
        self.test_dataset = CustomDataset(test_latents, test_labels)
        self.data_dims = data_dims

    def train_dataloader(self):
        """
        Returns the dataloader for training.

        Returns
        -------
        dataloader : dataloader
            Shuffled dataloader over the training split.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """
        Returns the dataloader for validation.

        Returns
        -------
        dataloader : dataloader
            Unshuffled dataloader over the test split.
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
class KFoldDataModule(pl.LightningDataModule):
    """
    Custom pytorch lightning datamodule class for the different modalities with k-fold
    cross validation

    Attributes
    ----------
    num_folds : int
        Total number of folds.
    sources : list
        List of source csv files. [Tabular1, Tabular2, Image]
    output_paths : dict
        Dictionary of output paths for saving the checkpoints, figures, and the losses.
    image_downsample_size : tuple
        Size to downsample the images to (height, width, depth) or (height, width)
        for 2D images. None if not downsampling. (default None)
    extra_log_string_dict : dict
        Dictionary of extra strings to add to the log.
    modality_methods : dict
        Dictionary of methods for loading the different modalities, keyed by modality type.
    fusion_model : class
        Fusion model class. e.g. "TabularCrossmodalAttention".
    modality_type : str
        The fusion model's modality type; selects the loader in modality_methods.
    batch_size : int
        Batch size.
    prediction_task : str
        Prediction type (binary, multiclass, regression).
    multiclass_dimensions : int
        Number of classes for multiclass prediction (None unless prediction_task is "multiclass").
    subspace_method : class
        Subspace method class (default None) (only for subspace methods).
    layer_mods : dict
        Dictionary of layer modifications to make to the subspace method. (default None)
    max_epochs : int
        Maximum number of epochs to train subspace methods for. (default 1000)
    dataset : tensor
        Tensor of predictive features. Created in prepare_data().
    data_dims : list
        List of data dimensions [mod1_dim, mod2_dim, img_dim]. Created in prepare_data().
    folds : list
        List of tuples of (train_dataset, test_dataset). Created in setup().
    train_dataset : tensor
        Training dataset of the fold most recently requested by a dataloader.
    test_dataset : tensor
        Test dataset of the fold most recently requested by a dataloader.
    own_early_stopping_callback : pytorch_lightning.callbacks.EarlyStopping
        Early stopping callback class.
    num_workers : int
        Number of workers for the dataloader (default 0).
    own_kfold_indices : list
        List of indices to use for k-fold cross validation (default None). If None, the
        k-fold indices are randomly selected. Structure is a list of tuples of
        (train_indices, test_indices). Must be the same length as num_folds.
    kwargs : dict
        Dictionary of extra arguments for the subspace method class.
    """

    def __init__(
        self,
        fusion_model,
        sources,
        output_paths,
        prediction_task,
        batch_size,
        num_folds,
        multiclass_dimensions,
        subspace_method=None,
        image_downsample_size=None,
        layer_mods=None,
        max_epochs=1000,
        extra_log_string_dict=None,
        own_early_stopping_callback=None,
        num_workers=0,
        own_kfold_indices=None,
        kwargs=None,
    ):
        """
        Parameters
        ----------
        fusion_model : class
            Fusion model class. e.g. "TabularCrossmodalAttention".
        sources : list
            List of source data files: csv or torch files.
        output_paths : dict
            Dictionary of output paths for saving the checkpoints, figures, and the losses.
        prediction_task : str
            Prediction task (binary, multiclass, regression).
        batch_size : int
            Batch size.
        num_folds : int
            Total number of folds.
        multiclass_dimensions : int
            Number of classes for multiclass prediction (ignored unless prediction_task
            is "multiclass").
        subspace_method : class
            Subspace method class (default None) (only for subspace methods).
        image_downsample_size : tuple
            Size to downsample the images to (height, width, depth) or (height, width)
            for 2D images. None if not downsampling. (default None)
        layer_mods : dict
            Dictionary of layer modifications to make to the subspace method. (default None)
        max_epochs : int
            Maximum number of epochs to train subspace methods for. (default 1000)
        extra_log_string_dict : dict
            Dictionary of extra strings to add to the log.
        own_early_stopping_callback : pytorch_lightning.callbacks.EarlyStopping
            Early stopping callback class (default None).
        num_workers : int
            Number of workers for the dataloader (default 0).
        own_kfold_indices : list
            List of indices to use for k-fold cross validation (default None). If None,
            the k-fold indices are randomly selected. Structure is a list of tuples of
            (train_indices, test_indices). Must be the same length as num_folds.
        kwargs : dict
            Dictionary of extra arguments for the subspace method class.
        """
        super().__init__()
        self.num_folds = num_folds  # total number of folds
        self.sources = sources
        self.output_paths = output_paths
        self.image_downsample_size = image_downsample_size
        self.extra_log_string_dict = extra_log_string_dict

        # one loader is enough: LoadDatasets reads (and validates) the CSVs in its
        # __init__, so creating one per modality re-read every file five times
        loader = LoadDatasets(self.sources, self.image_downsample_size)
        self.modality_methods = {
            "tabular1": loader.load_tabular1,
            "tabular2": loader.load_tabular2,
            "img": loader.load_img,
            "tabular_tabular": loader.load_tabular_tabular,
            "tabular_image": loader.load_tab_and_img,
        }

        self.prediction_task = prediction_task
        self.fusion_model = fusion_model
        self.modality_type = self.fusion_model.modality_type
        self.batch_size = batch_size
        self.subspace_method = subspace_method  # only for subspace methods
        self.multiclass_dimensions = (
            multiclass_dimensions if self.prediction_task == "multiclass" else None
        )
        self.layer_mods = layer_mods
        self.max_epochs = max_epochs
        self.own_early_stopping_callback = own_early_stopping_callback
        self.num_workers = num_workers
        self.own_kfold_indices = own_kfold_indices
        self.kwargs = kwargs

    def prepare_data(self):
        """
        Loads the data with the LoadDatasets loader for this model's modality type,
        setting self.dataset and self.data_dims.

        data_dims is [mod1_dim, mod2_dim, img_dim],
        i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
        i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
        32 features), and no image
        """
        self.dataset, self.data_dims = self.modality_methods[self.modality_type]()

    def kfold_split(self):
        """
        Splits the dataset into k folds

        Returns
        ------
        folds : list
            List of tuples of (train_dataset, test_dataset)

        Raises
        ------
        ValueError
            If own_kfold_indices is given but does not contain num_folds splits.
        """
        if self.own_kfold_indices is None:
            kf = KFold(n_splits=self.num_folds, shuffle=True)
            split_kf = kf.split(list(range(len(self.dataset))))
        else:
            split_kf = list(self.own_kfold_indices)
            if len(split_kf) != self.num_folds:
                raise ValueError(
                    f"own_kfold_indices must contain {self.num_folds} folds "
                    f"(num_folds), not {len(split_kf)}."
                )

        # one (train, test) pair of Subsets per fold
        return [
            (
                torch.utils.data.Subset(self.dataset, train_indices),
                torch.utils.data.Subset(self.dataset, val_indices),
            )
            for train_indices, val_indices in split_kf
        ]

    def _init_subspace_method(self, k, train_subspace):
        """Instantiate the subspace method for fold k and apply any layer modifications."""
        subspace_method = self.subspace_method(
            self,
            k=k,
            max_epochs=self.max_epochs,
            train_subspace=train_subspace,
        )
        if self.layer_mods is not None:
            subspace_method = model_modifier.modify_model_architecture(
                subspace_method,
                self.layer_mods,
            )
        return subspace_method

    def setup(
        self,
        checkpoint_path=None,
    ):
        """
        Splits the data into k folds, and runs the subspace method on each fold if specified.

        If a subspace method is set, each fold's datasets are replaced by CustomDatasets
        of the latent features and data_dims is updated. If checkpoint_path is given, the
        subspace method is loaded from the checkpoint instead of being trained.

        Parameters
        ----------
        checkpoint_path : str
            Path to the checkpoint file for the subspace method (default None).

        Returns
        ------
        None
        """
        self.folds = self.kfold_split()

        if self.subspace_method is None:
            return

        train_subspace = checkpoint_path is None
        data_dims = self.data_dims  # overwritten by every fold below
        new_folds = []
        for k, (train_dataset, test_dataset) in enumerate(self.folds):
            subspace_method = self._init_subspace_method(k, train_subspace)

            if train_subspace:
                # train the subspace method; this also returns the train latents
                train_latents, train_labels = subspace_method.train(
                    train_dataset, test_dataset
                )
            else:
                # NOTE(review): the same checkpoint_path is passed for every fold;
                # presumably load_ckpt uses the fold index k to pick the right one — confirm
                subspace_method.load_ckpt(checkpoint_path)
                train_latents, train_labels, _ = subspace_method.convert_to_latent(
                    train_dataset
                )

            test_latents, test_labels, data_dims = subspace_method.convert_to_latent(
                test_dataset
            )
            new_folds.append(
                (
                    CustomDataset(train_latents, train_labels),
                    CustomDataset(test_latents, test_labels),
                )
            )

        # update the folds with the latent-space datasets and the new data dimensions
        self.folds = new_folds
        self.data_dims = data_dims

    def train_dataloader(self, fold_idx):
        """
        Returns the dataloader for training.

        Parameters
        ----------
        fold_idx : int
            Index of the fold to use.

        Returns
        -------
        dataloader : dataloader
            Shuffled dataloader over the fold's training split.
        """
        self.train_dataset, self.test_dataset = self.folds[fold_idx]
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self, fold_idx):
        """
        Returns the dataloader for validation.

        Parameters
        ----------
        fold_idx : int
            Index of the fold to use.

        Returns
        -------
        dataloader : dataloader
            Unshuffled dataloader over the fold's test split.
        """
        self.train_dataset, self.test_dataset = self.folds[fold_idx]
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
class TrainTestGraphDataModule:
    """
    Custom pytorch lightning datamodule class for the different modalities with graph
    data structure.

    Attributes
    ----------
    sources : list
        List of source csv files.
    image_downsample_size : tuple
        Size to downsample the images to (height, width, depth) or (height, width)
        for 2D images.
    extra_log_string_dict : dict
        Dictionary of extra strings to add to the log.
    modality_methods : dict
        Dictionary of methods for loading the different modalities, keyed by modality type.
    fusion_model : class
        Fusion model class. e.g. "TabularCrossmodalAttention".
    modality_type : str
        The fusion model's modality type; selects the loader in modality_methods.
    test_size : float
        Fraction of data to use for testing.
    graph_creation_method : class
        Graph creation method class.
    graph_maker_instance : graph maker class
        Graph maker class instance. Created in setup().
    layer_mods : dict
        Dictionary of layer modifications to make to the graph maker method.
    dataset : tensor
        Tensor of predictive features. Created in prepare_data().
    data_dims : list
        List of data dimensions [mod1_dim, mod2_dim, img_dim]. Created in prepare_data().
    train_idxs : list
        List of indices for training. Created in setup().
    test_idxs : list
        List of indices for testing. Created in setup().
    graph_data : graph data structure
        Graph data structure. Created in setup().
    own_test_indices : list
        List of indices to use for testing (default None). If None, the test indices
        are randomly selected using the test_size parameter.
    """

    def __init__(
        self,
        fusion_model,
        sources,
        graph_creation_method,
        test_size,
        image_downsample_size=None,
        layer_mods=None,
        extra_log_string_dict=None,
        own_test_indices=None,
    ):
        """
        Parameters
        ----------
        fusion_model : class
            Fusion model class. e.g. "TabularCrossmodalAttention".
        sources : list
            List of source csv files.
        graph_creation_method : class
            Graph creation method class.
        test_size : float
            Fraction of data to use for testing.
        image_downsample_size : tuple
            Size to downsample the images to (height, width, depth) or (height, width)
            for 2D images. None if not downsampling. (default None)
        layer_mods : dict
            Dictionary of layer modifications to make to the graph maker method.
            (default None)
        extra_log_string_dict : dict
            Dictionary of extra strings to add to the log.
        own_test_indices : list
            List of indices to use for testing (default None). If None, the test
            indices are randomly selected using the test_size parameter.
        """
        super().__init__()
        self.sources = sources
        self.image_downsample_size = image_downsample_size
        self.extra_log_string_dict = extra_log_string_dict

        # one loader is enough: LoadDatasets reads (and validates) the CSVs in its
        # __init__, so creating one per modality re-read every file five times
        loader = LoadDatasets(self.sources, self.image_downsample_size)
        self.modality_methods = {
            "tabular1": loader.load_tabular1,
            "tabular2": loader.load_tabular2,
            "img": loader.load_img,
            "tabular_tabular": loader.load_tabular_tabular,
            "tabular_image": loader.load_tab_and_img,
        }

        self.fusion_model = fusion_model
        self.modality_type = self.fusion_model.modality_type
        self.test_size = test_size
        self.graph_creation_method = graph_creation_method
        self.layer_mods = layer_mods
        self.own_test_indices = own_test_indices

    def prepare_data(self):
        """
        Loads the data with the LoadDatasets loader for this model's modality type,
        setting self.dataset and self.data_dims.

        data_dims is [mod1_dim, mod2_dim, img_dim],
        i.e. [None, None, [100, 100, 100]] for image only (image dimensions 100 x 100 x 100)
        i.e. [8, 32, None] for tabular1 and tabular2 (tabular1 has 8 features, tabular2 has
        32 features), and no image
        """
        self.dataset, self.data_dims = self.modality_methods[self.modality_type]()

    def setup(self):
        """
        Gets the train and test indices (random unless own_test_indices is set), and
        builds the graph data structure with the graph creation method.

        Returns
        ------
        None
        """
        if self.own_test_indices is None:
            # random_split is used only to draw the index partition
            train_dataset, test_dataset = torch.utils.data.random_split(
                self.dataset, [1 - self.test_size, self.test_size]
            )
            self.train_idxs = train_dataset.indices
            self.test_idxs = test_dataset.indices
        else:
            self.test_idxs = self.own_test_indices
            held_out = set(self.test_idxs)
            self.train_idxs = [i for i in range(len(self.dataset)) if i not in held_out]

        # build the graph data structure, modifying the graph maker if specified
        self.graph_maker_instance = self.graph_creation_method(self.dataset)
        if self.layer_mods is not None:
            self.graph_maker_instance = model_modifier.modify_model_architecture(
                self.graph_maker_instance,
                self.layer_mods,
            )
        self.graph_data = self.graph_maker_instance.make_graph()

    def get_lightning_module(self):
        """
        Gets the lightning module using the pytorch geometric lightning module for
        converting the graph data structure into a pytorch dataloader.

        The test indices are used for the validation, test and prediction nodes, and
        the whole graph is loaded at once (loader="full").

        Returns
        ------
        lightning_module : lightning module
            Lightning module for converting the graph data structure into a pytorch
            dataloader.
        """
        return LightningNodeData(
            data=self.graph_data,
            input_train_nodes=self.train_idxs,
            input_val_nodes=self.test_idxs,
            input_test_nodes=self.test_idxs,
            input_pred_nodes=self.test_idxs,
            loader="full",
        )
class KFoldGraphDataModule:
    """
    Data module helper for graph-based fusion models using k-fold cross validation.

    It loads the data for the fusion model's modality type, splits the node indices
    into k folds, builds a graph for each fold, and wraps each fold in a pytorch
    geometric ``LightningNodeData`` module (see ``get_lightning_module``).

    Attributes
    ----------
    num_folds : int
        Total number of folds.
    image_downsample_size : tuple
        Size to downsample the images to (height, width, depth) or (height, width) for 2D
        images.
    sources : list
        List of source csv files. [Tabular1, Tabular2, Image]
    modality_methods : dict
        Dictionary of methods for loading the different modalities.
    fusion_model : class
        Fusion model class. e.g. "TabularCrossmodalAttention".
    graph_creation_method : class
        Graph creation method class.
    graph_maker_instance : graph maker class
        Graph maker class instance (the one created for the last fold in setup()).
    layer_mods : dict
        Dictionary of layer modifications to make to the graph maker method.
    dataset : tensor
        Tensor of predictive features.
    data_dims : list
        List of data dimensions [mod1_dim, mod2_dim, img_dim]
    folds : list
        List of tuples of (graph_data, train_idxs, test_idxs)
    own_kfold_indices : list
        User-supplied list of (train_indices, test_indices) tuples, or None.
    """

    def __init__(
        self,
        num_folds,
        fusion_model,
        sources,
        graph_creation_method,
        image_downsample_size=None,
        layer_mods=None,
        extra_log_string_dict=None,
        own_kfold_indices=None,
    ):
        """
        Parameters
        ----------
        num_folds : int
            Total number of folds.
        fusion_model : class
            Fusion model class. e.g. "TabularCrossmodalAttention".
        sources : list
            List of source csv files.
        graph_creation_method : class
            Graph creation method class.
        image_downsample_size : tuple
            Size to downsample the images to (height, width, depth) or (height, width) for 2D
            images. None if not downsampling. (default None)
        layer_mods : dict
            Dictionary of layer modifications to make to the graph maker method.
            (default None)
        extra_log_string_dict : dict
            Dictionary of extra strings to add to the log.
        own_kfold_indices : list
            List of indices to use for k-fold cross validation (default None). If None, the
            k-fold indices are randomly selected. Structure is a list of tuples of
            (train_indices, test_indices). Must be the same length as num_folds.

        Raises
        ------
        ValueError
            If both ``own_kfold_indices`` and ``num_folds`` are given and the number of
            supplied folds does not equal ``num_folds``.
        """
        super().__init__()
        self.num_folds = num_folds  # total number of folds

        # Enforce the documented contract before doing any (possibly expensive) loading.
        if (
            own_kfold_indices is not None
            and num_folds is not None
            and len(own_kfold_indices) != num_folds
        ):
            raise ValueError(
                f"own_kfold_indices contains {len(own_kfold_indices)} folds but "
                f"num_folds is {num_folds}; they must match."
            )

        self.image_downsample_size = image_downsample_size
        self.sources = sources
        self.extra_log_string_dict = extra_log_string_dict

        # A single loader is enough: only one of its methods is used for a given fusion
        # model, so there is no need to construct LoadDatasets once per modality.
        loader = LoadDatasets(self.sources, self.image_downsample_size)
        self.modality_methods = {
            "tabular1": loader.load_tabular1,
            "tabular2": loader.load_tabular2,
            "img": loader.load_img,
            "tabular_tabular": loader.load_tabular_tabular,
            "tabular_image": loader.load_tab_and_img,
        }

        self.fusion_model = fusion_model
        self.modality_type = self.fusion_model.modality_type
        self.graph_creation_method = graph_creation_method
        self.layer_mods = layer_mods
        self.own_kfold_indices = own_kfold_indices

    def prepare_data(self):
        """
        Loads the data with the LoadDatasets method matching the fusion model's
        modality type, setting ``self.dataset`` and ``self.data_dims``.

        Returns
        -------
        None
        """
        self.dataset, self.data_dims = self.modality_methods[self.modality_type]()

    def kfold_split(self):
        """
        Splits the dataset into k folds.

        Uses a shuffled sklearn ``KFold`` with ``num_folds`` splits unless
        ``own_kfold_indices`` was supplied, in which case those splits are used as-is.

        Returns
        -------
        folds : list
            List of tuples of (train_dataset, test_dataset), each a
            ``torch.utils.data.Subset`` of ``self.dataset``.
        """
        if self.own_kfold_indices is None:
            indices = list(range(len(self.dataset)))
            kf = KFold(n_splits=self.num_folds, shuffle=True)
            split_kf = kf.split(indices)
        else:
            split_kf = self.own_kfold_indices

        folds = []
        for train_indices, val_indices in split_kf:
            # split the dataset into train and test sets for each fold
            train_dataset = torch.utils.data.Subset(self.dataset, train_indices)
            test_dataset = torch.utils.data.Subset(self.dataset, val_indices)
            folds.append((train_dataset, test_dataset))

        return folds  # list of tuples of (train_dataset, test_dataset)

    def setup(self):
        """
        Gets the k-fold train and test indices, and builds a graph data structure for
        each fold.

        Returns
        -------
        None
        """
        new_folds = []

        for train_dataset, test_dataset in self.kfold_split():
            train_idxs = train_dataset.indices  # train node idxs for this fold
            test_idxs = test_dataset.indices  # test node idxs for this fold

            # get the graph data structure
            self.graph_maker_instance = self.graph_creation_method(self.dataset)

            # modify the graph maker architecture if specified
            if self.layer_mods is not None:
                self.graph_maker_instance = model_modifier.modify_model_architecture(
                    self.graph_maker_instance,
                    self.layer_mods,
                )

            # make the graph data structure
            graph_data = self.graph_maker_instance.make_graph()

            new_folds.append((graph_data, train_idxs, test_idxs))

        self.folds = new_folds  # list of tuples of (graph_data, train_idxs, test_idxs)

    def get_lightning_module(self):
        """
        Returns the lightning modules using the pytorch geometric lightning module for
        converting each fold's graph data structure into a pytorch dataloader.

        The test indices of each fold are used for its validation, test and prediction
        node sets.

        Returns
        -------
        lightning_modules : list
            List of lightning modules for each fold.
        """
        lightning_modules = []
        for graph_data, train_idxs, test_idxs in self.folds:
            lightning_module = LightningNodeData(
                data=graph_data,
                input_train_nodes=train_idxs,
                input_val_nodes=test_idxs,
                input_test_nodes=test_idxs,
                input_pred_nodes=test_idxs,
                loader="full",
            )
            lightning_modules.append(lightning_module)

        return lightning_modules  # list of lightning modules for each fold
def prepare_fusion_data(
    prediction_task,
    fusion_model,
    data_paths,
    output_paths,
    kfold=False,
    num_folds=None,
    test_size=0.2,
    batch_size=8,
    multiclass_dimensions=None,
    image_downsample_size=None,
    layer_mods=None,
    max_epochs=1000,
    checkpoint_path=None,
    extra_log_string_dict=None,
    own_early_stopping_callback=None,
    num_workers=0,
    test_indices=None,
    own_kfold_indices=None,
):
    """
    Gets the data module for a specific fusion model and training protocol.

    Parameters
    ----------
    prediction_task : str
        Prediction task (binary, multiclass, regression).
    fusion_model : class
        Fusion model class.
    data_paths : dict
        Dictionary of data paths with keys "tabular1", "tabular2", "image".
    output_paths : dict
        Dictionary of output paths with keys "checkpoints", "figures", "losses".
    kfold : bool
        Whether to use kfold cross validation (default False means train/test split).
    num_folds : int or None
        Number of folds for kfold cross validation (default None). Required when
        ``kfold`` is True unless ``own_kfold_indices`` is given, in which case it
        defaults to ``len(own_kfold_indices)``.
    test_size : float
        Fraction of data to use for testing when using train/test split (default 0.2).
    batch_size : int
        Batch size (default 8).
    multiclass_dimensions : int
        Number of classes for multiclass prediction (default None).
    image_downsample_size : tuple
        Tuple of image dimensions to downsample to (default None).
        e.g. (100, 100, 100) for 3D images, (100, 100) for 2D images.
    layer_mods : dict
        Dictionary of layer modifications (default None).
    max_epochs : int
        Maximum number of epochs to train subspace methods for. (default 1000)
    checkpoint_path : list
        List containing paths to call checkpoint file. Length of the list is the number
        of trainable subspace models in the fusion model (e.g., DAETabImgMaps requires
        two models to be pre-trained, so we'd pass 2 checkpoint paths in the list.
        (default None will result in the default lightning format).
    extra_log_string_dict : dict
        Dictionary of extra strings to add to a subspace method checkpoint file name
        (default None). e.g. if you're running the same model with different
        hyperparameters, you can add the hyperparameters.
        Input format {"name": "value"}. In the run name, the extra string will be added
        as "name_value". And a tag will be added as "name_value".
    own_early_stopping_callback : pytorch_lightning.callbacks.EarlyStopping
        Early stopping callback class (default None).
    num_workers : int
        Number of workers for the dataloader (default 0).
    test_indices : list or None
        List of indices to use for testing (default None). If None, then random split
        is used.
    own_kfold_indices : list or None
        List of indices to use for k-fold cross validation (default None). If None, then
        random split is used.

    Returns
    -------
    dm : datamodule or list
        Datamodule for the specified fusion method. For graph fusion models with
        ``kfold=True`` this is a list of lightning modules, one per fold.

    Raises
    ------
    ValueError
        If ``kfold`` is used together with ``own_early_stopping_callback``, or if
        ``kfold`` is True but neither ``num_folds`` nor ``own_kfold_indices`` is given.
    """
    if kfold and own_early_stopping_callback is not None:
        raise ValueError(
            "Cannot use own early stopping callback with kfold cross validation yet. Working on fixing this currently (Nov 2023)"
        )

    if kfold and num_folds is None:
        if own_kfold_indices is None:
            raise ValueError(
                "num_folds must be given when kfold=True and own_kfold_indices is None."
            )
        # The user-supplied splits define how many folds there are.
        num_folds = len(own_kfold_indices)

    # Getting the data paths from the data_paths dictionary into a list
    data_sources = [
        data_paths["tabular1"],
        data_paths["tabular2"],
        data_paths["image"],
    ]

    # Non-subspace fusion models may not define this attribute; the datamodules read it.
    if not hasattr(fusion_model, "subspace_method"):
        fusion_model.subspace_method = None

    if fusion_model.fusion_type == "graph":
        if kfold:
            graph_data_module = KFoldGraphDataModule(
                num_folds=num_folds,
                fusion_model=fusion_model,
                sources=data_sources,
                graph_creation_method=fusion_model.graph_maker,
                image_downsample_size=image_downsample_size,
                layer_mods=layer_mods,
                extra_log_string_dict=extra_log_string_dict,
                own_kfold_indices=own_kfold_indices,
            )
        else:
            graph_data_module = TrainTestGraphDataModule(
                fusion_model,
                sources=data_sources,
                graph_creation_method=fusion_model.graph_maker,
                test_size=test_size,
                image_downsample_size=image_downsample_size,
                layer_mods=layer_mods,
                extra_log_string_dict=extra_log_string_dict,
                own_test_indices=test_indices,
            )

        graph_data_module.prepare_data()
        graph_data_module.setup()
        data_module = graph_data_module.get_lightning_module()

        # Attach the run configuration to every graph lightning module.
        shared_attrs = {
            "data_dims": graph_data_module.data_dims,
            "own_early_stopping_callback": own_early_stopping_callback,
            # NOTE(review): in k-fold mode KFoldGraphDataModule.setup reassigns
            # graph_maker_instance on every fold, so every fold receives the instance
            # created for the last fold.
            "graph_maker_instance": graph_data_module.graph_maker_instance,
            "output_paths": output_paths,
            "prediction_task": prediction_task,
            "multiclass_dimensions": multiclass_dimensions,
        }
        if kfold:
            shared_attrs["num_folds"] = num_folds

        # k-fold yields a list of lightning modules (one per fold); train/test yields one.
        modules_to_update = data_module if kfold else [data_module]
        for dm_instance in modules_to_update:
            for attr_name, value in shared_attrs.items():
                setattr(dm_instance, attr_name, value)

    else:  # anything other than graph fusion
        if kfold:
            data_module = KFoldDataModule(
                fusion_model,
                sources=data_sources,
                output_paths=output_paths,
                prediction_task=prediction_task,
                batch_size=batch_size,
                num_folds=num_folds,
                multiclass_dimensions=multiclass_dimensions,
                subspace_method=fusion_model.subspace_method,
                image_downsample_size=image_downsample_size,
                layer_mods=layer_mods,
                max_epochs=max_epochs,
                extra_log_string_dict=extra_log_string_dict,
                own_early_stopping_callback=own_early_stopping_callback,
                num_workers=num_workers,
                own_kfold_indices=own_kfold_indices,
            )
        else:
            data_module = TrainTestDataModule(
                fusion_model,
                sources=data_sources,
                output_paths=output_paths,
                prediction_task=prediction_task,
                batch_size=batch_size,
                test_size=test_size,
                multiclass_dimensions=multiclass_dimensions,
                subspace_method=fusion_model.subspace_method,
                image_downsample_size=image_downsample_size,
                layer_mods=layer_mods,
                max_epochs=max_epochs,
                extra_log_string_dict=extra_log_string_dict,
                own_early_stopping_callback=own_early_stopping_callback,
                num_workers=num_workers,
                test_indices=test_indices,
            )
        data_module.prepare_data()
        data_module.setup(checkpoint_path=checkpoint_path)

    return data_module