Source code for fusilli.fusionmodels.tabularfusion.attention_weighted_GNN

"""
Attention weighted GNN model: the edge weights are the attention weights from a pre-trained MLP and the node features are the second modality.
"""

import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, ChebConv
import torch.nn.functional as F
from fusilli.utils import check_model_validity
import lightning.pytorch as pl
from fusilli.utils.training_utils import (
    get_checkpoint_filenames_for_subspace_models,
    init_trainer,
)
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch import Trainer


class AttentionWeightMLP(pl.LightningModule):
    """
    MLP based on ConcatTabularData for the attention weighted GNN.

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed.
    multiclass_dimensions : int
        Number of classes for multiclass classification. If not multiclass classification,
        this is None.
    mod1_dim : int
        Number of features of the first modality.
    mod2_dim : int
        Number of features of the second modality.
    weighting_layers : nn.ModuleDict
        Module dictionary containing the weighting layers. The layers must have input size of
        the first modality dimension plus the second modality dimension and output size of the
        first modality dimension plus the second modality dimension.
    fused_dim : int
        Number of features of the fused layers. This is the final output shape of the
        weighting layers.
    fused_layers : nn.Sequential
        Sequential layer containing the fused layers. Calculated in the
        :meth:`~ParentFusionModel.set_fused_layers` method.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers. The final prediction layers
        take in the number of features of the fused layers as input. Calculated in the
        :meth:`~ParentFusionModel.set_final_pred_layers` method.
    """

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes for multiclass classification. If not multiclass
            classification, this is None.
        """
        super().__init__()
        self.prediction_task = prediction_task
        self.multiclass_dimensions = multiclass_dimensions
        self.mod1_dim = data_dims[0]
        self.mod2_dim = data_dims[1]

        # Encoder/decoder-shaped stack whose final output has the same width as the
        # concatenated input; that output is later sigmoided into per-feature weights.
        self.weighting_layers = nn.ModuleDict(
            {
                "Layer 1": nn.Sequential(
                    nn.Linear(self.mod1_dim + self.mod2_dim, 256), nn.ReLU()
                ),
                "Layer 2": nn.Sequential(nn.Linear(256, 128), nn.ReLU()),
                "Layer 3": nn.Sequential(nn.Linear(128, 128), nn.ReLU()),
                "Layer 4": nn.Sequential(nn.Linear(128, 256), nn.ReLU()),
                "Layer 5": nn.Sequential(
                    nn.Linear(256, self.mod1_dim + self.mod2_dim), nn.ReLU()
                ),
            }
        )

        self.fused_dim = self.mod1_dim + self.mod2_dim
        ParentFusionModel.set_fused_layers(self, self.fused_dim)
        self.calc_fused_layers()

    def calc_fused_layers(self):
        """
        Checks the parameters of the model, sets the final prediction layers and calculates
        the fused layers based on any modifications to the model.

        Raises
        ------
        ValueError
            If the final weighting layer does not output ``mod1_dim + mod2_dim`` features.
        """
        check_model_validity.check_dtype(
            self.weighting_layers, nn.ModuleDict, "weighting_layers"
        )

        self.fused_dim = self.mod1_dim + self.mod2_dim

        # check final layer output size is the same as the fused dim; the first module of the
        # last weighting layer is expected to be the Linear layer
        final_weighting_layer = self.weighting_layers[
            list(self.weighting_layers.keys())[-1]
        ][0]
        if final_weighting_layer.out_features != self.fused_dim:
            raise ValueError(
                (
                    "Incorrect attribute range: The final weighting_layer layer must have an output size of"
                    f" {self.fused_dim} (the same as the input). The final weighting layer output size is currently: {final_weighting_layer.out_features}"
                )
            )

        self.fused_layers, out_dim = check_model_validity.check_fused_layers(
            self.fused_layers, self.fused_dim
        )

        ParentFusionModel.set_final_pred_layers(self, out_dim)

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            First modality input data.
        x2 : torch.Tensor
            Second modality input data. Concatenated with ``x1`` along dimension 1.

        Returns
        -------
        out_pred : torch.Tensor
            Prediction output of the model.
        attention_weights : torch.Tensor
            Attention weights of the model. Final layer of the model sigmoided.
        """
        x = torch.cat((x1, x2), dim=1).to(torch.float32)

        for layer in self.weighting_layers.values():
            x = layer(x)

        attention_weights = torch.sigmoid(x)

        out_fused_layers = self.fused_layers(x)
        out_pred = self.final_prediction(out_fused_layers)

        return out_pred, attention_weights

    def _compute_loss(self, batch):
        """Return the MSE loss of one ``(x1, x2, y)`` batch; multiclass labels are one-hot encoded."""
        x1, x2, y = batch

        y_hat, _ = self.forward(x1, x2)

        if self.prediction_task == "multiclass":
            # turn the labels into one hot vectors so they match the prediction width
            y = F.one_hot(y, num_classes=self.multiclass_dimensions)

        # NOTE(review): MSE is used for every prediction task, including classification.
        return F.mse_loss(y_hat.squeeze(), y.to(torch.float32).squeeze())

    def training_step(self, batch, batch_idx):
        """
        Training step of the model.

        Parameters
        ----------
        batch : tuple
            Tuple containing the two modalities input data and the labels.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        loss : torch.Tensor
            Loss of the model.
        """
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, logger=None)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step of the model.

        Parameters
        ----------
        batch : tuple
            Tuple containing the two modalities input data and the labels.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        loss : torch.Tensor
            Loss of the model.
        """
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, logger=None)
        return loss

    def configure_optimizers(self):
        """
        Configure the optimiser of the model.

        Returns
        -------
        optimiser : torch.optim
            Adam optimiser over all model parameters with learning rate 1e-3.
        """
        optimiser = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimiser

    def create_attention_weights(self, x1, x2):
        """
        Create the attention weights of the model for a given input.

        Runs without gradient tracking: the weights are only used as fixed feature
        multipliers, so no autograd graph is built over the (possibly whole) dataset.

        Parameters
        ----------
        x1 : torch.Tensor
            First modality input data.
        x2 : torch.Tensor
            Second modality input data.

        Returns
        -------
        weights : torch.Tensor
            Attention weights of the model. Final layer of the model sigmoided.
        """
        with torch.no_grad():
            _, weights = self.forward(x1, x2)
        return weights
class AttentionWeightedGraphMaker:
    """
    Class to make the graph structure for the attention weighted GNN.

    Attributes
    ----------
    dataset : Dataset
        Dataset containing the tabular data.
    early_stop_callback : EarlyStopping
        Early stopping callback for the MLP model.
    edge_probability_threshold : int
        Probability threshold for the edges of the graph, as a percentile in (0, 100].
        e.g. 75 means the edges associated with the top 25% of probabilities are used.
    attention_MLP_test_size : float
        Test size for the MLP model, strictly between 0 and 1.
    max_epochs : int
        Maximum number of epochs for the MLP model. Default -1.
    AttentionWeightingMLPInstance : AttentionWeightMLP
        Instance of the MLP model.
    trainer : Trainer
        Trainer of the model.
    train_idxs : list
        List of the indices of the training data.
    test_idxs : list
        List of the indices of the test data.
    """

    def __init__(self, dataset):
        """
        Parameters
        ----------
        dataset : Dataset
            Dataset containing the tabular data. ``dataset[:]`` must yield
            ``(modality 1, modality 2, labels)``.
        """
        self.dataset = dataset
        self.early_stop_callback = EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=15,
            verbose=False,
            mode="min",
        )
        self.edge_probability_threshold = 75
        self.attention_MLP_test_size = 0.2

        # initialise MLP; index the full dataset once rather than once per attribute
        all_data = self.dataset[:]
        data_dims = [all_data[0].shape[1], all_data[1].shape[1]]
        labels = all_data[2]

        # infer the prediction task from the label dtype / number of distinct labels
        if torch.is_floating_point(labels[0]):
            prediction_task = "regression"
            multiclass_dim = None
        else:
            num_classes = len(np.unique(labels))
            if num_classes == 2:
                prediction_task = "binary"
                multiclass_dim = None
            else:
                prediction_task = "multiclass"
                multiclass_dim = num_classes

        self.AttentionWeightingMLPInstance = AttentionWeightMLP(
            prediction_task, data_dims, multiclass_dim
        )
        self.max_epochs = -1

    def check_params(self):
        """
        Checks the parameters of the model.

        Raises
        ------
        ValueError
            If ``edge_probability_threshold`` is not in (0, 100] or ``attention_MLP_test_size``
            is not strictly between 0 and 1.
        """
        # check the edge probability threshold is an int in (0, 100]
        check_model_validity.check_dtype(
            self.edge_probability_threshold, int, "edge_probability_threshold"
        )
        if not 0 < self.edge_probability_threshold <= 100:
            raise ValueError(
                (
                    "Incorrect attribute range: The edge_probability_threshold must be greater than 0 "
                    f"and at most 100. The threshold is currently: {self.edge_probability_threshold}"
                )
            )

        # check early stopping is an EarlyStopping object
        check_model_validity.check_dtype(
            self.early_stop_callback, EarlyStopping, "early_stop_callback"
        )

        # check attention MLP test size is a float strictly between 0 and 1; a test size of 1
        # would leave no training data. Only range-check numbers so that other types reach the
        # dtype check instead of failing on the comparison.
        test_size = self.attention_MLP_test_size
        if isinstance(test_size, (int, float)) and not 0 < test_size < 1:
            raise ValueError(
                (
                    "Incorrect attribute range: The attention_MLP_test_size must be between 0 and 1, "
                    f"exclusive. The test size is currently: {test_size}"
                )
            )
        check_model_validity.check_dtype(
            self.attention_MLP_test_size, float, "attention_MLP_test_size"
        )

    def make_graph(self):
        """
        Make the graph structure for the attention weighted GNN.

        The attention-weighting MLP is trained on a random split of the dataset. Its attention
        weights then weight the concatenated tabular features of every subject. Edges link
        subjects whose weighted phenotypes are closest, and node features are the second
        tabular modality. Node ``i`` of the returned graph is ``self.dataset[i]``: node
        features, edges and labels all follow the dataset's original order.

        Returns
        -------
        data : Data
            Data object containing the graph structure.
        """
        # get out the tabular data
        all_data = self.dataset[:]
        tab1, tab2, all_labels = all_data[0], all_data[1], all_data[2]

        # split the dataset to train/validate the attention MLP
        train_dataset, test_dataset = torch.utils.data.random_split(
            self.dataset,
            [1 - self.attention_MLP_test_size, self.attention_MLP_test_size],
        )
        self.train_idxs = train_dataset.indices
        self.test_idxs = test_dataset.indices

        self._fit_attention_mlp(train_dataset, test_dataset)

        # Attention weights for every subject, in the dataset's original order, so row i
        # lines up with tab2[i] and all_labels[i]. Normalising over all subjects at once keeps
        # train and test subjects on the same scale.
        attention_weights = (
            self.AttentionWeightingMLPInstance.create_attention_weights(tab1, tab2)
        ).detach()
        attention_weights = attention_weights / torch.sum(attention_weights)

        # make the weighted phenotypes: multiply the concatenated data by attention weights
        all_tab = torch.cat((tab1, tab2), dim=1)
        weighted_phenotypes = all_tab * attention_weights

        edge_index, edge_attr = self._edges_from_weighted_phenotypes(
            weighted_phenotypes
        )

        # the node features are the second modality
        data = Data(x=tab2, edge_index=edge_index, edge_attr=edge_attr, y=all_labels)

        return data

    def _fit_attention_mlp(self, train_dataset, test_dataset):
        """Fit the attention MLP on ``train_dataset`` with early stopping on ``test_dataset``."""
        train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
        val_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        self.trainer = Trainer(
            num_sanity_val_steps=0,
            callbacks=[self.early_stop_callback],
            log_every_n_steps=2,
            logger=False,
            enable_checkpointing=False,
            max_epochs=self.max_epochs,
        )

        self.trainer.fit(
            self.AttentionWeightingMLPInstance, train_dataloader, val_dataloader
        )
        self.trainer.validate(self.AttentionWeightingMLPInstance, val_dataloader)

    def _edges_from_weighted_phenotypes(self, weighted_phenotypes):
        """Return ``(edge_index, edge_attr)`` linking subjects whose weighted phenotypes are closest."""
        # squared pairwise distances between the weighted phenotypes
        distances = torch.cdist(weighted_phenotypes, weighted_phenotypes) ** 2

        # normalise to go between 0 and 1; skip when every subject coincides to avoid 0/0 = NaN
        max_distance = torch.max(distances)
        if max_distance > 0:
            distances = distances / max_distance
        distances = distances.detach().cpu().numpy()

        # get probability of each edge from the normalised distances
        probs = np.exp(-distances)

        # take away the identity: the diagonal (exp(0) = 1) becomes 0, below every
        # off-diagonal probability, so no self-loops are created
        probs = probs - np.eye(probs.shape[0])

        # keep only the edges above the requested percentile of probabilities
        top_percentage = np.percentile(probs, self.edge_probability_threshold)
        edge_indices = np.stack(np.where(probs > top_percentage), axis=0)

        edge_index = torch.tensor(edge_indices, dtype=torch.long)
        # edge attributes are the normalised squared distances of the kept edges
        edge_attr = torch.tensor(distances[edge_indices[0], edge_indices[1]])

        return edge_index, edge_attr
class AttentionWeightedGNN(ParentFusionModel, nn.Module):
    """
    Graph neural network with the edge weighting as the distances between each nodes'
    weighted phenotypes and the node features as the second tabular modality features.

    This is a model inspired by method in `Bintsi et al. (2023) <https://arxiv.org/abs/2307.04639>`_ :
    *Multimodal brain age estimation using interpretable adaptive population-graph learning*.

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed.
    graph_conv_layers : nn.Sequential
        Sequential layer containing the graph convolutional layers. By default ChebConv
        layers.
    dropout_prob : float
        Dropout probability applied after every graph convolutional layer (only while
        training). Default 0.2.
    fused_dim : int
        Number of features of the fused layers. This is the final output shape of the graph
        convolutional layers.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers. The final prediction layers
        take in ``fused_dim`` features as input.
    """

    #: str: Name of the method.
    method_name = "Attention-weighted GNN"
    #: str: Type of modality.
    modality_type = "tabular_tabular"
    #: str: Type of fusion.
    fusion_type = "graph"
    #: class: Graph maker class.
    graph_maker = AttentionWeightedGraphMaker

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes for multiclass classification. If not multiclass
            classification, this is None.
        """
        ParentFusionModel.__init__(
            self, prediction_task, data_dims, multiclass_dimensions
        )

        self.prediction_task = prediction_task

        # the node features are the second modality, hence the mod2_dim input width
        self.graph_conv_layers = nn.Sequential(
            ChebConv(self.mod2_dim, 64, K=3),
            ChebConv(64, 128, K=3),
            ChebConv(128, 256, K=3),
            ChebConv(256, 256, K=3),
        )

        self.dropout_prob = 0.2

        self.calc_fused_layers()

    def calc_fused_layers(self):
        """
        Checks the parameters of the model, sets the final prediction layers and calculates
        the fused layers based on any modifications to the model.

        Raises
        ------
        ValueError
            If ``dropout_prob`` is outside [0, 1].
        """
        # check graph layers are sequential
        check_model_validity.check_dtype(
            self.graph_conv_layers, nn.Sequential, "graph_conv_layers"
        )

        check_model_validity.check_dtype(self.dropout_prob, float, "dropout_prob")

        # check dropout probability is between 0 and 1
        if self.dropout_prob < 0 or self.dropout_prob > 1:
            raise ValueError(
                (
                    f"Incorrect attribute range: The dropout probability must be between, 0 and 1, inclusive. The dropout probability is currently: {self.dropout_prob}"
                )
            )

        # the width of the last graph conv layer feeds the final prediction layers
        self.fused_dim = self.graph_conv_layers[-1].out_channels
        self.set_final_pred_layers(self.fused_dim)

    def forward(self, x):
        """
        Forward pass of the model.

        Parameters
        ----------
        x : tuple
            Tuple containing the tabular data and the graph data structure:
            (node features, edge indices, edge attributes).

        Returns
        -------
        out : torch.Tensor
            Prediction output of the model.
        """
        check_model_validity.check_model_input(x, tuple_flag=True, correct_length=3)

        x_n, edge_index, edge_attr = x

        # edge attributes are passed to each conv layer as its edge weights
        for layer in self.graph_conv_layers:
            x_n = layer(x_n, edge_index, edge_attr)
            x_n = x_n.relu()
            x_n = F.dropout(x_n, p=self.dropout_prob, training=self.training)

        out = self.final_prediction(x_n)

        return out