fusilli.fusionmodels.tabularfusion.edge_corr_gnn
Edge correlation GNN model: edges are weighted by the correlation between the nodes' first tabular modality features.
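The edge weighting can be made concrete with a small example. The sketch below computes the Pearson correlation between two nodes' first-modality feature vectors. It assumes Pearson correlation is the measure and that each node holds one feature vector; this page does not state the exact correlation measure the library uses. `pearson` is a hypothetical helper, not part of fusilli.

```python
from __future__ import annotations

import math


def pearson(a: list[float], b: list[float]) -> float:
    """Pearson correlation between two nodes' feature vectors (hypothetical helper).

    Each vector is assumed to hold one node's first tabular modality features.
    """
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("vectors must have equal length of at least 2")
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        return 0.0  # assumption: a constant feature vector gets zero correlation
    return cov / math.sqrt(var_a * var_b)


# Two nodes whose first-modality features rise together: a high edge weight.
node_i = [1.0, 2.0, 3.0, 4.0]
node_j = [2.0, 4.1, 5.9, 8.0]
print(round(pearson(node_i, node_j), 3))
```

A value near 1 means the two nodes' first-modality profiles move together, so the edge between them carries a large weight.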
Classes
EdgeCorrGNN(prediction_task, data_dims, multiclass_dimensions) | Graph neural network whose edge weights are the correlations between nodes' first tabular modality features and whose node features are the second tabular modality features.
EdgeCorrGraphMaker(dataset) | Creates the graph data structure for the edge correlation GNN model.
- class EdgeCorrGNN(prediction_task, data_dims, multiclass_dimensions)
Bases: ParentFusionModel, Module
Graph neural network whose edge weights are the correlations between nodes' first tabular modality features and whose node features are the second tabular modality features.
- graph_maker
Callable that creates the graph data structure: EdgeCorrGraphMaker.
- Type: class
- graph_conv_layers
Sequential layer containing the graph convolutional layers.
- Type: nn.Sequential
- dropout_prob
Dropout probability. Default: 0.5
- Type: float
- final_prediction
Sequential layer containing the final prediction layers. The final prediction layers take in 256 features.
- Type: nn.Sequential
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) – Type of prediction to be performed.
data_dims (list) – List containing the dimensions of the data.
multiclass_dimensions (int) – Number of classes in the multiclass classification task.
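Together, `prediction_task` and `multiclass_dimensions` determine how many outputs the model produces. The sketch below assumes the task names "binary", "multiclass" and "regression" and the common convention of a single output unit for binary and regression tasks. `output_dim` is a hypothetical helper, not part of the documented API.

```python
from typing import Optional


def output_dim(prediction_task: str, multiclass_dimensions: Optional[int] = None) -> int:
    """Number of output units for the final prediction layer (hypothetical helper).

    Assumes one logit for binary classification, one value for regression,
    and one logit per class for multiclass classification.
    """
    if prediction_task in ("binary", "regression"):
        return 1
    if prediction_task == "multiclass":
        if multiclass_dimensions is None or multiclass_dimensions < 2:
            raise ValueError("multiclass tasks need multiclass_dimensions >= 2")
        return multiclass_dimensions
    raise ValueError(f"unknown prediction_task: {prediction_task!r}")


print(output_dim("binary"))         # 1
print(output_dim("multiclass", 3))  # 3
```

Under these assumptions, `multiclass_dimensions` only matters when the task is multiclass.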
- calc_fused_layers()
Calculates the number of features after the fusion layer.
- Return type: None
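The method returns None, which suggests it stores the fused feature count on the model and does not return it. The sketch below illustrates that kind of bookkeeping under stated assumptions. The graph convolution stack is described by hypothetical (in_features, out_features) pairs, and the fused dimension is taken as the last layer's output width, so that it lines up with the 256 features `final_prediction` takes in. `FusedDimSketch` and `fused_dim` are illustrative names; the library's actual implementation may differ.

```python
from __future__ import annotations

from typing import Optional


class FusedDimSketch:
    """Hypothetical stand-in for what calc_fused_layers computes.

    Assumes the graph convolution stack is described by (in_features,
    out_features) pairs; the library's real layers and sizes may differ.
    """

    def __init__(self, conv_sizes: list[tuple[int, int]]) -> None:
        self.conv_sizes = conv_sizes
        self.fused_dim: Optional[int] = None
        self.calc_fused_layers()

    def calc_fused_layers(self) -> None:
        if not self.conv_sizes:
            raise ValueError("need at least one graph convolution layer")
        # Each layer must accept the width the previous layer produced.
        for (_, prev_out), (next_in, _) in zip(self.conv_sizes, self.conv_sizes[1:]):
            if prev_out != next_in:
                raise ValueError(f"layer mismatch: {prev_out} -> {next_in}")
        # Features after fusion = output width of the last graph conv layer.
        # Like the documented method, this stores the result and returns None.
        self.fused_dim = self.conv_sizes[-1][1]


model = FusedDimSketch([(10, 64), (64, 128), (128, 256)])
print(model.fused_dim)  # 256
```

Here 10 stands in for a hypothetical number of second-modality node features.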
- forward(x)
Forward pass of the model.
- Parameters:
x (tuple) – Tuple containing the tabular data and the graph data structure: (node features, edge indices, edge attributes).
- Returns:
List containing the output of the model.
- Return type: list
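The sketch below shows how a forward pass might consume the documented `(node features, edge indices, edge attributes)` tuple and return a list. In place of the model's real graph convolution layers, which this page does not specify, it uses a plain edge-weighted mean aggregation, so it is an illustration only. The data layout is also an assumption: edge indices are a `[sources, targets]` pair, there is one weight per edge, and each node keeps a self-loop of weight 1.

```python
from __future__ import annotations


def forward_sketch(x: tuple) -> list:
    """Hypothetical single edge-weighted aggregation step (not the library's layers).

    x = (node_features, edge_indices, edge_attributes), where
      node_features:   list of per-node feature lists (second modality)
      edge_indices:    pair of lists [sources, targets]
      edge_attributes: one weight per edge (first-modality correlation)
    Returns a one-element list, mirroring the documented list return type.
    """
    node_features, edge_indices, edge_attributes = x
    sources, targets = edge_indices
    n_nodes = len(node_features)
    n_feats = len(node_features[0])

    # Start each node from its own features (a self-loop with weight 1).
    summed = [list(f) for f in node_features]
    weight_totals = [1.0] * n_nodes

    # Add each neighbour's features, scaled by the edge weight, to the target node.
    for src, dst, w in zip(sources, targets, edge_attributes):
        for k in range(n_feats):
            summed[dst][k] += w * node_features[src][k]
        weight_totals[dst] += w

    # Normalise by total incoming weight: a weighted mean over the neighbourhood.
    out = [[v / weight_totals[i] for v in summed[i]] for i in range(n_nodes)]
    return [out]


nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [[0, 1], [1, 0]]  # node 0 -> node 1 and node 1 -> node 0
weights = [0.9, 0.9]
print(forward_sketch((nodes, edges, weights)))
```

Nodes 0 and 1 are strongly correlated, so each one's features are pulled toward the other's. Node 2 has no edges and keeps its own features.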
- fusion_type = 'graph'
Type of fusion.
- Type: str
- graph_maker
Alias of EdgeCorrGraphMaker.
- method_name = 'Edge Correlation GNN'
Name of the method.
- Type: str
- modality_type = 'tabular_tabular'
Type of modality.
- Type: str
- class EdgeCorrGraphMaker(dataset)
Bases: object
Creates the graph data structure for the edge correlation GNN model.
- dataset
Dataset containing the tabular data.
- Type: torch.utils.data.Dataset
- threshold
Correlation that two nodes' first tabular modality features must reach for the nodes to be connected by an edge. Default: 0.8
- Type: float
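The sketch below shows how a correlation threshold could turn first-modality features into edges and edge weights. It makes several assumptions. Correlation is Pearson, computed between every ordered pair of distinct nodes. An edge is kept only when the signed correlation strictly exceeds `threshold`. The correlation itself becomes the edge weight. The page does not say whether the library uses signed or absolute correlation, or `>` versus `>=`. `make_corr_graph` is a hypothetical function, not the EdgeCorrGraphMaker API.

```python
from __future__ import annotations

import math


def pearson(a: list[float], b: list[float]) -> float:
    """Pearson correlation between two feature vectors (hypothetical helper)."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        return 0.0  # assumption: constant vectors are treated as uncorrelated
    return cov / math.sqrt(var_a * var_b)


def make_corr_graph(tab1: list[list[float]], threshold: float = 0.8):
    """Return ([sources, targets], weights) built from first-modality features.

    Hypothetical sketch: edge i -> j for every ordered pair of distinct nodes
    whose Pearson correlation exceeds `threshold`; the weight is that correlation.
    """
    sources, targets, weights = [], [], []
    n = len(tab1)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # no self-loops
            r = pearson(tab1[i], tab1[j])
            if r > threshold:
                sources.append(i)
                targets.append(j)
                weights.append(r)
    return [sources, targets], weights


tab1 = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 2.0, 1.0]]
print(make_corr_graph(tab1))  # ([[0, 1], [1, 0]], [1.0, 1.0])
```

Nodes 0 and 1 are perfectly correlated and get connected in both directions. Node 2 is anti-correlated with both and stays isolated. In the documented design, these edges and weights would be paired with the second modality's features as node features, giving the (node features, edge indices, edge attributes) triple that forward takes.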