fusilli.fusionmodels.tabularfusion.attention_weighted_GNN
Attention weighted GNN model: the edge weights are the attention weights from a pre-trained MLP and the node features are the second modality.
Classes
| AttentionWeightMLP | MLP based on ConcatTabularData for the attention weighted GNN. |
| AttentionWeightedGNN | Graph neural network with the edge weighting as the distances between the nodes' weighted phenotypes and the node features as the second tabular modality features. |
| AttentionWeightedGraphMaker | Class to make the graph structure for the attention weighted GNN. |
- class AttentionWeightMLP(prediction_task, data_dims, multiclass_dimensions)
Bases:
LightningModule
MLP based on ConcatTabularData for the attention weighted GNN.
- prediction_task
Type of prediction to be performed.
- Type:
str
- multiclass_dimensions
Number of classes for multiclass classification. If not multiclass classification, this is None.
- Type:
int
- mod1_dim
Number of features of the first modality.
- Type:
int
- mod2_dim
Number of features of the second modality.
- Type:
int
- weighting_layers
Module dictionary containing the weighting layers. The layers must take an input of size mod1_dim + mod2_dim and produce an output of size mod1_dim + mod2_dim.
- Type:
nn.ModuleDict
- fused_dim
Number of features of the fused layers. This is the final output shape of the weighting layers.
- Type:
int
- fused_layers
Sequential layer containing the fused layers. Calculated in the set_fused_layers() method.
- Type:
nn.Sequential
- final_prediction
Sequential layer containing the final prediction layers. The final prediction layers take in the number of features of the fused layers as input. Calculated in the set_final_pred_layers() method.
- Type:
nn.Sequential
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) – Type of prediction to be performed.
data_dims (list) – List containing the dimensions of the data.
multiclass_dimensions (int) – Number of classes for multiclass classification. If not multiclass classification, this is None.
- calc_fused_layers()
Checks the parameters of the model, sets the final prediction layers and calculates the fused layers based on any modifications to the model.
- configure_optimizers()
Configure the optimiser of the model.
- Returns:
optimiser – Optimiser of the model.
- Return type:
torch.optim
- create_attention_weights(x)
Create the attention weights of the model for a given input.
- Parameters:
x (tuple) – Tuple containing the input data of the two modalities.
- Returns:
weights – Attention weights of the model: the output of the model's final layer passed through a sigmoid.
- Return type:
torch.Tensor
- forward(x)
- Parameters:
x (tuple) – Tuple containing the input data of the two modalities.
- Returns:
out_pred (torch.Tensor) – Prediction output of the model.
attention_weights (torch.Tensor) – Attention weights of the model: the output of the model's final layer passed through a sigmoid.
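The attention-weight computation described above can be illustrated with a minimal sketch. It uses plain Python rather than PyTorch to stay dependency-free, and it assumes a single weighting layer whose input and output sizes both equal mod1_dim + mod2_dim. The helper names (`sigmoid`, `linear`, `sketch_attention_weights`) and the randomly initialised parameters are hypothetical and are not part of fusilli.

```python
import math
import random


def sigmoid(value):
    """Squash a real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-value))


def linear(inputs, weight, bias):
    """Dense layer: one output per row of `weight`."""
    return [
        sum(w * x for w, x in zip(row, inputs)) + b
        for row, b in zip(weight, bias)
    ]


def sketch_attention_weights(mod1_row, mod2_row, weight, bias):
    """Concatenate both modalities, apply one weighting layer, then a sigmoid."""
    concatenated = list(mod1_row) + list(mod2_row)
    return [sigmoid(h) for h in linear(concatenated, weight, bias)]


mod1_dim, mod2_dim = 2, 3
total_dim = mod1_dim + mod2_dim

# Hypothetical, randomly initialised parameters (a trained model would learn these).
rng = random.Random(0)
weight = [[rng.uniform(-1, 1) for _ in range(total_dim)] for _ in range(total_dim)]
bias = [0.0] * total_dim

attention_weights = sketch_attention_weights([0.5, -1.2], [1.0, 0.0, 2.0], weight, bias)
print(attention_weights)  # one value in (0, 1) per concatenated feature
```

Because of the sigmoid, every weight lies strictly between 0 and 1, and there is one weight per feature of the concatenated input.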
- class AttentionWeightedGNN(prediction_task, data_dims, multiclass_dimensions)
Bases:
ParentFusionModel, Module
Graph neural network with the edge weighting as the distances between the nodes' weighted phenotypes and the node features as the second tabular modality features.
This model is inspired by the method in Bintsi et al. (2023): Multimodal brain age estimation using interpretable adaptive population-graph learning.
- prediction_task
Type of prediction to be performed.
- Type:
str
- graph_conv_layers
Sequential layer containing the graph convolutional layers. By default ChebConv layers.
- Type:
nn.Sequential
- fused_dim
Number of features of the fused layers. This is the final output shape of the graph convolutional layers.
- Type:
int
- final_prediction
Sequential layer containing the final prediction layers.
- Type:
nn.Sequential
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) – Type of prediction to be performed.
data_dims (list) – List containing the dimensions of the data.
multiclass_dimensions (int) – Number of classes for multiclass classification. If not multiclass classification, this is None.
- calc_fused_layers()
Checks the parameters of the model, sets the final prediction layers and calculates the fused layers based on any modifications to the model.
- forward(x)
Forward pass of the model.
- Parameters:
x (tuple) – Tuple containing the tabular data and the graph data structure: (node features, edge indices, edge attributes).
- Returns:
List containing the output of the model.
- Return type:
list
- fusion_type = 'graph'
Type of fusion.
- Type:
str
- graph_maker
alias of AttentionWeightedGraphMaker
- method_name = 'Attention-weighted GNN'
Name of the method.
- Type:
str
- modality_type = 'tabular_tabular'
Type of modality.
- Type:
str
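To make the (node features, edge indices, edge attributes) input of forward(x) more concrete, here is a minimal sketch of how such a triple could be laid out. It assumes the COO-style edge index convention used by PyTorch Geometric (two parallel lists of source and target node indices) and uses plain Python lists instead of tensors. The `adjacency_to_coo` helper and the toy values are hypothetical and are not part of fusilli.

```python
def adjacency_to_coo(adjacency):
    """Convert a dense weighted adjacency matrix into COO edge lists.

    Returns ([sources, targets], edge_attr), skipping self-loops and zero weights.
    """
    sources, targets, edge_attr = [], [], []
    for i, row in enumerate(adjacency):
        for j, weight in enumerate(row):
            if i != j and weight > 0:
                sources.append(i)
                targets.append(j)
                edge_attr.append(weight)
    return [sources, targets], edge_attr


# Three nodes (subjects); each row holds that subject's second-modality features.
node_features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# Toy symmetric edge weights between subjects.
adjacency = [
    [0.0, 0.8, 0.0],
    [0.8, 0.0, 0.3],
    [0.0, 0.3, 0.0],
]

edge_index, edge_attr = adjacency_to_coo(adjacency)
x = (node_features, edge_index, edge_attr)
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr)   # [0.8, 0.8, 0.3, 0.3]
```

Each undirected connection appears twice, once per direction, and `edge_attr` holds one weight per column of `edge_index`.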
- class AttentionWeightedGraphMaker(dataset)
Bases:
object
Class to make the graph structure for the attention weighted GNN.
- dataset
Dataset containing the tabular data.
- Type:
Dataset
- early_stop_callback
Early stopping callback for the MLP model.
- Type:
EarlyStopping
- edge_probability_threshold
Probability threshold for the edges of the graph, expressed as a percentile. For example, 75 means that only the edges associated with the top 25% of probabilities are used.
- Type:
int
- attention_MLP_test_size
Test size for the MLP model.
- Type:
float
- max_epochs
Maximum number of epochs for the MLP model. Default -1.
- Type:
int
- AttentionWeightingMLPInstance
Instance of the MLP model.
- Type:
AttentionWeightMLP
- trainer
Trainer of the model.
- Type:
Trainer
- train_idxs
List of the indices of the training data.
- Type:
list
- test_idxs
List of the indices of the test data.
- Type:
list
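The effect of edge_probability_threshold can be sketched as a percentile cut-off over the edge probabilities. The code below is an illustration under stated assumptions, not the library's implementation: it assumes a linear-interpolation percentile and assumes that edges at or above the cut-off are kept. The `percentile` and `keep_top_edges` helpers and the toy probabilities are hypothetical.

```python
def percentile(values, q):
    """Linear-interpolation percentile of `values` for q in [0, 100]."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


def keep_top_edges(edge_probabilities, edge_probability_threshold):
    """Keep only the edges whose probability reaches the percentile cut-off."""
    cutoff = percentile(list(edge_probabilities.values()), edge_probability_threshold)
    return {edge: p for edge, p in edge_probabilities.items() if p >= cutoff}


# Toy edge probabilities keyed by (source node, target node).
edge_probabilities = {
    (0, 1): 0.1, (0, 2): 0.2, (0, 3): 0.3, (1, 2): 0.4,
    (1, 3): 0.5, (2, 3): 0.6, (0, 4): 0.7, (1, 4): 0.8,
}

kept = keep_top_edges(edge_probabilities, 75)
print(sorted(kept))  # [(0, 4), (1, 4)]
```

With a threshold of 75, two of the eight toy edges survive, which matches the "top 25% of probabilities" reading of the attribute.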