fusilli.fusionmodels.tabularfusion.attention_weighted_GNN

Attention weighted GNN model: the edge weights are the attention weights from a pre-trained MLP and the node features are the second modality.

Classes

AttentionWeightMLP(prediction_task, ...)

MLP based on ConcatTabularData for the attention weighted GNN.

AttentionWeightedGNN(prediction_task, ...)

Graph neural network in which the edge weights are the distances between the nodes' weighted phenotypes and the node features are the features of the second tabular modality.

AttentionWeightedGraphMaker(dataset)

Class to make the graph structure for the attention weighted GNN.

class AttentionWeightMLP(prediction_task, data_dims, multiclass_dimensions)[source]

Bases: LightningModule

MLP based on ConcatTabularData for the attention weighted GNN.

prediction_task

Type of prediction to be performed.

Type:

str

multiclass_dimensions

Number of classes for multiclass classification. If not multiclass classification, this is None.

Type:

int

mod1_dim

Number of features of the first modality.

Type:

int

mod2_dim

Number of features of the second modality.

Type:

int

weighting_layers

Module dictionary containing the weighting layers. The layers must take an input of size mod1_dim + mod2_dim (the first modality's dimension plus the second modality's dimension) and produce an output of the same size.

Type:

nn.ModuleDict
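A minimal sketch of a weighting-layer dictionary that satisfies this size contract is shown below. The three layers, their hidden widths (64 and 32) and the helper name `make_weighting_layers` are hypothetical; the layers fusilli actually builds may differ.

```python
import torch
from torch import nn


def make_weighting_layers(mod1_dim: int, mod2_dim: int) -> nn.ModuleDict:
    # Hypothetical helper: input and output width are both mod1_dim + mod2_dim,
    # as the attribute description requires; hidden widths are illustrative.
    in_dim = mod1_dim + mod2_dim
    return nn.ModuleDict(
        {
            "Layer 1": nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU()),
            "Layer 2": nn.Sequential(nn.Linear(64, 32), nn.ReLU()),
            "Layer 3": nn.Sequential(nn.Linear(32, in_dim)),
        }
    )


layers = make_weighting_layers(mod1_dim=10, mod2_dim=15)
x = torch.randn(4, 25)  # batch of 4 samples, both modalities concatenated
for layer in layers.values():  # ModuleDict preserves insertion order
    x = layer(x)
print(x.shape)
```

Because the output width matches the input width, each output unit can later act as a weight for one input feature.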

fused_dim

Number of features of the fused layers. This is the final output shape of the weighting layers.

Type:

int

fused_layers

Sequential layer containing the fused layers. Calculated in the set_fused_layers() method.

Type:

nn.Sequential

final_prediction

Sequential layer containing the final prediction layers. The final prediction layers take in the number of features of the fused layers as input. Calculated in the set_final_pred_layers() method.

Type:

nn.Sequential
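The final prediction head might look like the following sketch. It assumes the task strings `"binary"`, `"multiclass"` and `"regression"` and a single linear output layer; `make_final_prediction` is a hypothetical helper, not fusilli's `set_final_pred_layers()`.

```python
from typing import Optional

import torch
from torch import nn


def make_final_prediction(
    fused_dim: int, prediction_task: str, multiclass_dimensions: Optional[int]
) -> nn.Sequential:
    # Hypothetical helper: one output unit for binary or regression tasks,
    # one unit per class for multiclass tasks.
    if prediction_task == "multiclass":
        if multiclass_dimensions is None:
            raise ValueError("multiclass_dimensions must be set for multiclass tasks")
        out_dim = multiclass_dimensions
    elif prediction_task in ("binary", "regression"):
        out_dim = 1
    else:
        raise ValueError(f"Unknown prediction_task: {prediction_task!r}")
    return nn.Sequential(nn.Linear(fused_dim, out_dim))


head = make_final_prediction(fused_dim=25, prediction_task="multiclass", multiclass_dimensions=3)
out = head(torch.randn(4, 25))  # one logit per class for each of the 4 samples
```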

__init__(prediction_task, data_dims, multiclass_dimensions)[source]
Parameters:
  • prediction_task (str) – Type of prediction to be performed.

  • data_dims (list) – List containing the dimensions of the data.

  • multiclass_dimensions (int) – Number of classes for multiclass classification. If not multiclass classification, this is None.

calc_fused_layers()[source]

Checks the parameters of the model, sets the final prediction layers and calculates the fused layers based on any modifications to the model.

configure_optimizers()[source]

Configure the optimiser of the model.

Returns:

optimiser – Optimiser of the model.

Return type:

torch.optim.Optimizer
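As a sketch, an optimiser hook only needs to return an optimiser built over the model's parameters. Adam, the learning rate of 1e-3 and the standalone function form are assumptions for illustration; the optimiser and settings fusilli uses are not stated here.

```python
import torch
from torch import nn


def configure_optimizers(model: nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    # Adam and the default learning rate are illustrative assumptions.
    return torch.optim.Adam(model.parameters(), lr=lr)


model = nn.Linear(25, 1)
optimiser = configure_optimizers(model)
```

Inside a LightningModule the same body would sit in the `configure_optimizers(self)` hook, using `self.parameters()`.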

create_attention_weights(x)[source]

Create the attention weights of the model for a given input.

Parameters:

x (tuple) – Tuple containing the input data of the two modalities.

Returns:

weights – Attention weights of the model: the output of the model's final layer passed through a sigmoid.

Return type:

torch.Tensor
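The attention-weight computation can be sketched as follows, assuming the two modalities are concatenated along the feature dimension and the weighting layers are applied in insertion order before the sigmoid. The layer sizes are hypothetical.

```python
import torch
from torch import nn


def create_attention_weights(weighting_layers: nn.ModuleDict, x: tuple) -> torch.Tensor:
    # Concatenate the two modalities feature-wise: (batch, mod1_dim + mod2_dim).
    out = torch.cat(x, dim=1)
    # Run the weighting layers in order.
    for layer in weighting_layers.values():
        out = layer(out)
    # Squash to [0, 1] so each feature receives a weight.
    return torch.sigmoid(out)


mod1_dim, mod2_dim = 10, 15
layers = nn.ModuleDict(
    {
        "Layer 1": nn.Sequential(nn.Linear(mod1_dim + mod2_dim, 32), nn.ReLU()),
        "Layer 2": nn.Sequential(nn.Linear(32, mod1_dim + mod2_dim)),
    }
)
weights = create_attention_weights(layers, (torch.randn(4, mod1_dim), torch.randn(4, mod2_dim)))
```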

forward(x)[source]

Forward pass of the model.

Parameters:

x (tuple) – Tuple containing the input data of the two modalities.

Returns:

  • out_pred (torch.Tensor) – Prediction output of the model.

  • attention_weights (torch.Tensor) – Attention weights of the model: the output of the model's final layer passed through a sigmoid.

training_step(batch, batch_idx)[source]

Training step of the model.

Parameters:
  • batch (tuple) – Tuple containing the input data of the two modalities and the labels.

  • batch_idx (int) – Index of the batch.

Returns:

loss – Loss of the model.

Return type:

torch.Tensor
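A training step in this style unpacks the batch, runs the forward pass and returns a scalar loss. The sketch below uses a hypothetical stand-in model (`TinyTwoModalityModel`) whose forward returns `(out_pred, attention_weights)` like the documented method, and assumes a regression task with mean-squared-error loss; how fusilli chooses its loss is not shown here.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyTwoModalityModel(nn.Module):
    # Hypothetical stand-in for AttentionWeightMLP, for illustration only.
    def __init__(self, mod1_dim: int, mod2_dim: int):
        super().__init__()
        fused_dim = mod1_dim + mod2_dim
        self.weighting = nn.Linear(fused_dim, fused_dim)
        self.final_prediction = nn.Linear(fused_dim, 1)

    def forward(self, x):
        concat = torch.cat(x, dim=1)
        attention_weights = torch.sigmoid(self.weighting(concat))
        # Hypothetical: predict from the attention-weighted features.
        out_pred = self.final_prediction(concat * attention_weights)
        return out_pred, attention_weights


def training_step(model: nn.Module, batch: tuple) -> torch.Tensor:
    # batch = (first-modality features, second-modality features, labels).
    x1, x2, y = batch
    out_pred, _attention_weights = model((x1, x2))
    # Regression assumed: one output per sample compared with float labels.
    return F.mse_loss(out_pred.squeeze(-1), y)


torch.manual_seed(0)
model = TinyTwoModalityModel(mod1_dim=3, mod2_dim=4)
batch = (torch.randn(8, 3), torch.randn(8, 4), torch.randn(8))
loss = training_step(model, batch)
loss.backward()  # gradients reach both the weighting and prediction layers
```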

validation_step(batch, batch_idx)[source]

Validation step of the model.

Parameters:
  • batch (tuple) – Tuple containing the input data of the two modalities and the labels.

  • batch_idx (int) – Index of the batch.

Returns:

loss – Loss of the model.

Return type:

torch.Tensor

class AttentionWeightedGNN(prediction_task, data_dims, multiclass_dimensions)[source]

Bases: ParentFusionModel, Module

Graph neural network in which the edge weights are the distances between the nodes' weighted phenotypes and the node features are the features of the second tabular modality.

This model is inspired by the method in Bintsi et al. (2023), "Multimodal brain age estimation using interpretable adaptive population-graph learning".
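The edge weighting described above can be sketched as pairwise distances between attention-weighted feature vectors. This assumes Euclidean distance, one row per node, and attention weights with the same shape as the features (or broadcastable to it); which features count as the "phenotypes", and how the distances are then turned into edge probabilities, are not specified here.

```python
import numpy as np


def weighted_distances(features: np.ndarray, attention_weights: np.ndarray) -> np.ndarray:
    # Scale each node's features element-wise by its attention weights.
    weighted = features * attention_weights
    # Pairwise Euclidean distances between nodes: shape (n_nodes, n_nodes).
    diff = weighted[:, None, :] - weighted[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


features = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
uniform = weighted_distances(features, np.ones_like(features))
# A zero weight removes the second feature from every distance.
first_only = weighted_distances(features, np.array([1.0, 0.0]))
```

This is how the attention weights shape the graph: features the MLP down-weights contribute less to how close two nodes appear.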

prediction_task

Type of prediction to be performed.

Type:

str

graph_conv_layers

Sequential layer containing the graph convolutional layers. By default, these are ChebConv layers.

Type:

nn.Sequential
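For reference, the Chebyshev spectral filter that ChebConv layers implement can be written densely in NumPy. This is an illustrative sketch, not PyTorch Geometric's implementation: it assumes symmetric Laplacian normalisation, λ_max = 2, no bias, and one weight matrix per Chebyshev order.

```python
import numpy as np


def cheb_conv(x, adj, thetas, lambda_max=2.0):
    # x: (n_nodes, in_dim) features; adj: (n_nodes, n_nodes) symmetric edge weights;
    # thetas: K weight matrices of shape (in_dim, out_dim), one per Chebyshev order.
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    # D^{-1/2}, with isolated nodes (degree 0) left at 0 to avoid division by zero.
    d_inv_sqrt = np.zeros_like(deg, dtype=float)
    d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    # Symmetric normalised Laplacian: L = I - D^{-1/2} A D^{-1/2}.
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    # Rescale the spectrum to [-1, 1]: L_hat = 2 L / lambda_max - I.
    lap_hat = 2.0 * lap / lambda_max - np.eye(n)

    # Chebyshev recurrence: T0 = x, T1 = L_hat x, Tk = 2 L_hat T(k-1) - T(k-2).
    t_prev, t_curr = x, lap_hat @ x
    out = t_prev @ thetas[0]
    if len(thetas) > 1:
        out = out + t_curr @ thetas[1]
    for theta in thetas[2:]:
        t_prev, t_curr = t_curr, 2.0 * lap_hat @ t_curr - t_prev
        out = out + t_curr @ theta
    return out


adj = np.array([[0.0, 1.0], [1.0, 0.0]])
x = np.array([[1.0], [2.0]])
out_k2 = cheb_conv(x, adj, [np.eye(1), np.eye(1)])
out_k3 = cheb_conv(x, adj, [np.eye(1)] * 3)
```

With K weight matrices, each node's output mixes information from neighbours up to K - 1 hops away, which is what lets the edge weights influence the prediction.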

fused_dim

Number of features of the fused layers. This is the final output shape of the graph convolutional layers.

Type:

int

final_prediction

Sequential layer containing the final prediction layers. The final prediction layers take in the number of features of the fused layers (fused_dim) as input.

Type:

nn.Sequential

__init__(prediction_task, data_dims, multiclass_dimensions)[source]
Parameters:
  • prediction_task (str) – Type of prediction to be performed.

  • data_dims (list) – List containing the dimensions of the data.

  • multiclass_dimensions (int) – Number of classes for multiclass classification. If not multiclass classification, this is None.

calc_fused_layers()[source]

Checks the parameters of the model, sets the final prediction layers and calculates the fused layers based on any modifications to the model.

forward(x)[source]

Forward pass of the model.

Parameters:

x (tuple) – Tuple containing the tabular data and the graph data structure: (node features, edge indices, edge attributes).

Returns:

List containing the output of the model.

Return type:

list
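The graph part of the input tuple is typically a sparse edge list. The sketch below converts a dense edge-weight matrix into an edge index of shape `(2, num_edges)` plus one attribute per edge, assuming the PyTorch Geometric COO convention; `dense_to_coo` is a hypothetical helper and the node features are random stand-ins for the second modality.

```python
import numpy as np


def dense_to_coo(weights: np.ndarray):
    # Keep every non-zero entry as a directed edge (self-loops excluded here).
    mask = weights != 0
    np.fill_diagonal(mask, False)
    src, dst = np.nonzero(mask)
    edge_index = np.stack([src, dst])  # shape (2, num_edges)
    edge_attr = weights[src, dst]      # one weight per edge, same order
    return edge_index, edge_attr


weights = np.array([
    [0.0, 0.5, 0.0],
    [0.5, 0.0, 0.2],
    [0.0, 0.2, 0.0],
])
node_features = np.random.default_rng(0).normal(size=(3, 4))  # stand-in second-modality features
edge_index, edge_attr = dense_to_coo(weights)
x = (node_features, edge_index, edge_attr)  # (node features, edge indices, edge attributes)
```

A symmetric weight matrix yields each undirected edge twice, once per direction, which is how undirected graphs are usually stored in this format.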

fusion_type = 'graph'

Type of fusion.

Type:

str

graph_maker

alias of AttentionWeightedGraphMaker

method_name = 'Attention-weighted GNN'

Name of the method.

Type:

str

modality_type = 'tabular_tabular'

Type of modality.

Type:

str

class AttentionWeightedGraphMaker(dataset)[source]

Bases: object

Class to make the graph structure for the attention weighted GNN.

dataset

Dataset containing the tabular data.

Type:

Dataset

early_stop_callback

Early stopping callback for the MLP model.

Type:

EarlyStopping

edge_probability_threshold

Percentile threshold for the edge probabilities of the graph. For example, 75 means that only the edges whose probabilities fall in the top 25% are used.

Type:

int
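The percentile rule above can be sketched with `np.percentile`. Whether fusilli keeps values strictly above the cutoff, and whether the diagonal is excluded before computing the percentile, are assumptions of this sketch.

```python
import numpy as np


def threshold_edges(probs: np.ndarray, edge_probability_threshold: float) -> np.ndarray:
    # Cutoff below which edge_probability_threshold percent of the values fall.
    cutoff = np.percentile(probs, edge_probability_threshold)
    # Keep only values strictly above the cutoff (roughly the top 100 - threshold percent).
    return probs > cutoff


probs = np.arange(1, 9, dtype=float).reshape(2, 4) / 10  # 0.1 ... 0.8
keep = threshold_edges(probs, 75)
```

With eight values and a threshold of 75, the two largest probabilities (0.7 and 0.8) survive, i.e. the top 25%.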

attention_MLP_test_size

Fraction of the data held out as the test set when training the MLP model.

Type:

float

max_epochs

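Maximum number of epochs for training the MLP model. Default -1, which in the PyTorch Lightning Trainer means no epoch limit, so training runs until the early stopping callback stops it.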

Type:

int

AttentionWeightingMLPInstance

Instance of the MLP model.

Type:

AttentionWeightMLP

trainer

Trainer used to train the MLP model.

Type:

Trainer

train_idxs

List of the indices of the training data.

Type:

list

test_idxs

List of the indices of the test data.

Type:

list
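One way to produce `train_idxs` and `test_idxs` from `attention_MLP_test_size` is a seeded random split, sketched below. The helper name, the fixed seed and rounding the test count up are assumptions; fusilli may split the data differently.

```python
import numpy as np


def split_indices(n_samples: int, test_size: float, seed: int = 0):
    # Hypothetical helper: shuffle all sample indices reproducibly.
    rng = np.random.default_rng(seed)
    idxs = rng.permutation(n_samples)
    # Hold out ceil(test_size * n_samples) samples for testing.
    n_test = int(np.ceil(test_size * n_samples))
    test_idxs = sorted(idxs[:n_test].tolist())
    train_idxs = sorted(idxs[n_test:].tolist())
    return train_idxs, test_idxs


train_idxs, test_idxs = split_indices(n_samples=10, test_size=0.2)
```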

__init__(dataset)[source]
Parameters:

dataset (Dataset) – Dataset containing the tabular data.

check_params()[source]

Checks the parameters of the model.

make_graph()[source]

Make the graph structure for the attention weighted GNN.

Returns:

data – Data object containing the graph structure.

Return type:

Data