Source code for fusilli.fusionmodels.tabularfusion.edge_corr_gnn

"""
Edge correlation GNN model: edges are weighted by the correlation between the nodes' first tabular modality features.
"""

import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from fusilli.utils import check_model_validity


[docs] class EdgeCorrGraphMaker: """ Creates the graph data structure for the edge correlation GNN model. Attributes ---------- dataset : torch.utils.data.Dataset Dataset containing the tabular data. threshold : float How correlated the nodes need to be to be connected. Default: 0.8 """
[docs] def __init__(self, dataset): """ Parameters ---------- dataset : torch.utils.data.Dataset Dataset containing the tabular data. """ self.dataset = dataset self.threshold = 0.8 # how correlated the nodes need to be to be connected
[docs] def check_params(self): """ Checks the parameters of the model. Returns ------- None """ if self.threshold < 0 or self.threshold > 1: raise ValueError( ( "Incorrect attribute range: The threshold must be between 0 and 1, " f"inclusive. The threshold is currently: {self.threshold}" ) ) check_model_validity.check_dtype(self.threshold, float, "threshold")
[docs] def make_graph(self): """ Creates the graph data structure. Returns ------- data : torch_geometric.data.Data Graph data structure containing the tabular data. """ self.check_params() tab1 = self.dataset[:][0] tab2 = self.dataset[:][1] labels = self.dataset[:][2] num_nodes = tab1.shape[0] # correlation matrix between nodes' tab1 features corr_matrix = torch.corrcoef(tab1) - torch.eye(num_nodes) edge_indices = np.where(np.abs(corr_matrix) >= self.threshold) edge_indices = np.stack(edge_indices, axis=0) # print("Number of edges: ", edge_indices.shape[1]) x = tab2 edge_index = torch.tensor(edge_indices, dtype=torch.long) edge_attr = ( corr_matrix[edge_indices[0], edge_indices[1]] + 1 ) # add 1 to make all edge_attr positive data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=labels) return data
[docs] class EdgeCorrGNN(ParentFusionModel, nn.Module): """ Graph neural network with the edge weighting as the first tabular modality correlations and the node features as the second tabular modality features. Attributes ---------- graph_maker : function Function that creates the graph data structure: :class:`~.EdgeCorrGraphMaker` graph_conv_layers : nn.Sequential Sequential layer containing the graph convolutional layers. dropout_prob : float Dropout probability. Default: 0.5 final_prediction : nn.Sequential Sequential layer containing the final prediction layers. The final prediction layers take in 256 features. """ #: str: Name of the method. method_name = "Edge Correlation GNN" #: str: Type of modality. modality_type = "tabular_tabular" #: str: Type of fusion. fusion_type = "graph" # class: Graph maker class. graph_maker = EdgeCorrGraphMaker
[docs] def __init__(self, prediction_task, data_dims, multiclass_dimensions): """ Parameters ---------- prediction_task : str Type of prediction to be performed. data_dims : list List containing the dimensions of the data. multiclass_dimensions : int Number of classes in the multiclass classification task. """ ParentFusionModel.__init__( self, prediction_task, data_dims, multiclass_dimensions ) self.prediction_task = prediction_task self.graph_conv_layers = nn.Sequential( GCNConv(self.mod2_dim, 64), GCNConv(64, 128), GCNConv(128, 256), GCNConv(256, 256), ) self.dropout_prob = 0.5 self.calc_fused_layers()
[docs] def calc_fused_layers(self): """ Calculates the number of features after the fusion layer. Returns ------- None """ if self.dropout_prob < 0 or self.dropout_prob > 1: raise ValueError( ( "Incorrect attribute range: The dropout probability must be between," f" 0 and 1, inclusive. The dropout probability is currently: {self.dropout_prob}" ) ) check_model_validity.check_dtype( self.graph_conv_layers, nn.Sequential, "graph_conv_layers" ) check_model_validity.check_dtype(self.dropout_prob, float, "dropout_prob") # make sure the first layer takes in the number of features of the second tabular modality self.graph_conv_layers[0] = GCNConv( self.mod2_dim, self.graph_conv_layers[0].out_channels ) self.fused_dim = self.graph_conv_layers[-1].out_channels self.set_final_pred_layers(self.fused_dim)
[docs] def forward(self, x): """ Forward pass of the model. Parameters ---------- x : tuple Tuple containing the tabular data and the graph data structure: (node features, edge indices, edge attributes) Returns ------- out : torch.Tensor Output tensor of the model. """ # ~~ Checks ~~ # check x is a tuple of length 3 check_model_validity.check_model_input(x, tuple_flag=True, correct_length=3) x_n, edge_index, edge_attr = x for layer in self.graph_conv_layers: x_n = layer(x_n, edge_index, edge_attr) x_n = x_n.relu() x_n = F.dropout(x_n, p=self.dropout_prob, training=self.training) out = self.final_prediction(x_n) return out