fusilli.fusionmodels.tabularfusion.edge_corr_gnn

Edge correlation GNN model: edges are weighted by the correlation between the nodes’ first tabular modality features.
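As a minimal, dependency-free sketch of the edge weighting described above: each node is one sample, and its edge weights come from correlating its first-modality feature vector with every other node's vector. This assumes Pearson correlation (this page only says "correlation"). `pearson` and `node_correlation_matrix` are hypothetical helpers, not fusilli functions.

```python
from math import sqrt
from typing import List, Sequence


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation between two equal-length feature vectors."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("vectors must have the same length (at least 2)")
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        raise ValueError("correlation is undefined for a constant vector")
    return cov / sqrt(var_a * var_b)


def node_correlation_matrix(tab1: List[List[float]]) -> List[List[float]]:
    """Entry [i][j] is the correlation between node i's and node j's
    first-modality feature vectors (one row per node/sample)."""
    return [[pearson(u, v) for v in tab1] for u in tab1]


# Usage: three nodes, each described by three first-modality features.
tab1 = [
    [1.0, 2.0, 3.0],  # node 0
    [2.0, 4.0, 6.0],  # node 1: node 0 scaled by 2, perfectly correlated
    [3.0, 2.0, 1.0],  # node 2: node 0 reversed, perfectly anti-correlated
]
corr = node_correlation_matrix(tab1)
print(corr[0][1], corr[0][2])  # 1.0 -1.0
```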

Classes

EdgeCorrGNN(prediction_task, data_dims, ...)

Graph neural network in which the edges are weighted by correlations between the first tabular modality features and the node features are the second tabular modality features.

EdgeCorrGraphMaker(dataset)

Creates the graph data structure for the edge correlation GNN model.

class EdgeCorrGNN(prediction_task, data_dims, multiclass_dimensions)

Bases: ParentFusionModel, Module

Graph neural network in which the edges are weighted by correlations between the first tabular modality features and the node features are the second tabular modality features.

graph_maker

Class that creates the graph data structure: EdgeCorrGraphMaker

Type:

class

graph_conv_layers

Sequential layer containing the graph convolutional layers.

Type:

nn.Sequential

dropout_prob

Dropout probability. Default: 0.5

Type:

float

final_prediction

Sequential layer containing the final prediction layers. The final prediction layers take in 256 features.

Type:

nn.Sequential

__init__(prediction_task, data_dims, multiclass_dimensions)
Parameters:
  • prediction_task (str) – Type of prediction to be performed.

  • data_dims (list) – List containing the dimensions of the data.

  • multiclass_dimensions (int) – Number of classes in the multiclass classification task.
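The sketch below is a hedged illustration of how these parameters could determine the width of the model's output. The task names "binary", "multiclass" and "regression" and the helper `final_layer_shape` are assumptions, not part of the API documented on this page; only the 256 input features come from the `final_prediction` description.

```python
from typing import Tuple

VALID_TASKS = ("binary", "multiclass", "regression")  # assumed task names


def final_layer_shape(prediction_task: str, multiclass_dimensions: int,
                      in_features: int = 256) -> Tuple[int, int]:
    """Hypothetical helper: (in, out) sizes of a last linear layer that maps
    the 256 features fed to final_prediction onto the task's outputs."""
    if prediction_task not in VALID_TASKS:
        raise ValueError(f"unknown prediction_task: {prediction_task!r}")
    if prediction_task == "multiclass":
        if multiclass_dimensions < 2:
            raise ValueError("multiclass needs at least 2 classes")
        return (in_features, multiclass_dimensions)  # one logit per class
    # binary and regression both produce a single output value here
    return (in_features, 1)


print(final_layer_shape("multiclass", 3))  # (256, 3)
print(final_layer_shape("binary", 3))      # (256, 1)
```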

calc_fused_layers()

Calculates the number of features after the fusion layer.

Return type:

None

forward(x)

Forward pass of the model.

Parameters:

x (tuple) – Tuple containing the tabular data and the graph data structure: (node features, edge indices, edge attributes)

Returns:

List containing the output of the model.

Return type:

list
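This page does not say which graph convolution `graph_conv_layers` contains, so the following is only a generic, hypothetical illustration of how a layer can consume the (node features, edge indices, edge attributes) tuple passed to `forward`: each node adds its neighbours' features scaled by the edge correlation. The two-list (source, target) edge layout mirrors torch_geometric's `edge_index` convention; `weighted_aggregate` is an assumed name.

```python
from typing import List, Tuple


def weighted_aggregate(
    node_features: List[List[float]],
    edge_index: Tuple[List[int], List[int]],
    edge_attr: List[float],
) -> List[List[float]]:
    """Hypothetical single propagation step: every node keeps its own features
    and adds each incoming neighbour's features scaled by that edge's weight."""
    sources, targets = edge_index
    if not (len(sources) == len(targets) == len(edge_attr)):
        raise ValueError("edge_index and edge_attr must describe the same edges")
    out = [list(row) for row in node_features]  # start from the self-contribution
    for src, dst, weight in zip(sources, targets, edge_attr):
        for k, value in enumerate(node_features[src]):
            out[dst][k] += weight * value  # message flows source -> target
    return out


# Usage: the tuple layout documented for forward(x).
x = ([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],  # node features
     ([0, 1], [1, 0]),                       # edge indices: 0 -> 1 and 1 -> 0
     [0.5, 0.5])                             # edge attributes (correlations)
print(weighted_aggregate(*x))  # [[1.0, 0.5], [0.5, 1.0], [2.0, 2.0]]
```

Real graph convolutions typically also normalise by node degree and apply a learnable weight matrix; the sketch omits both to keep the role of the edge weights visible. Node 2 has no edges, so its features pass through unchanged.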

fusion_type = 'graph'

Type of fusion.

Type:

str

graph_maker

alias of EdgeCorrGraphMaker

method_name = 'Edge Correlation GNN'

Name of the method.

Type:

str

modality_type = 'tabular_tabular'

Type of modality.

Type:

str

class EdgeCorrGraphMaker(dataset)

Bases: object

Creates the graph data structure for the edge correlation GNN model.

dataset

Dataset containing the tabular data.

Type:

torch.utils.data.Dataset

threshold

Minimum correlation between two nodes required for them to be connected by an edge. Default: 0.8

Type:

float
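A hedged sketch of how the threshold could turn a node-by-node correlation matrix into edges. Whether fusilli compares the raw or the absolute correlation, and whether the comparison is strict, is not stated here; the code assumes a strict `>` on the raw value and drops self-loops. `threshold_edges` is a hypothetical helper.

```python
from typing import List, Tuple


def threshold_edges(
    corr: List[List[float]], threshold: float = 0.8
) -> Tuple[Tuple[List[int], List[int]], List[float]]:
    """Keep node pairs whose correlation exceeds the threshold.

    Assumptions (not stated on this page): the raw correlation is compared
    with a strict '>', and self-correlations on the diagonal are skipped."""
    n = len(corr)
    sources: List[int] = []
    targets: List[int] = []
    weights: List[float] = []
    for i, row in enumerate(corr):
        if len(row) != n:
            raise ValueError("correlation matrix must be square")
        for j, value in enumerate(row):
            if i != j and value > threshold:
                sources.append(i)
                targets.append(j)
                weights.append(value)  # the correlation becomes the edge weight
    return (sources, targets), weights


corr = [
    [1.0, 0.9, 0.3],
    [0.9, 1.0, 0.85],
    [0.3, 0.85, 1.0],
]
edge_index, edge_attr = threshold_edges(corr)  # default threshold 0.8
print(edge_index, edge_attr)  # ([0, 1, 1, 2], [1, 0, 2, 1]) [0.9, 0.9, 0.85, 0.85]
```

Raising the threshold prunes the graph: in this example, 0.88 keeps only the edges between nodes 0 and 1, while 0.2 connects every pair of distinct nodes.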

__init__(dataset)
Parameters:

dataset (torch.utils.data.Dataset) – Dataset containing the tabular data.

check_params()

Checks the parameters of the model.

Return type:

None

make_graph()

Creates the graph data structure.

Returns:

data – Graph data structure containing the tabular data.

Return type:

torch_geometric.data.Data
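For orientation, the sketch below mirrors the fields of the returned object as a plain dataclass and checks that they agree with each other. The names `x`, `edge_index` and `edge_attr` follow the torch_geometric.data.Data attributes for node features, edge indices and edge attributes (in the real object `edge_index` is a 2 x num_edges tensor rather than two lists). Storing per-node labels in `y` is an assumption, and `GraphSketch` is a hypothetical name, not part of fusilli.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class GraphSketch:
    """Hypothetical plain-Python stand-in for the fields of the returned Data object."""
    x: List[List[float]]                     # node features (second tabular modality)
    edge_index: Tuple[List[int], List[int]]  # (source nodes, target nodes)
    edge_attr: List[float]                   # one correlation weight per edge
    y: List[float]                           # one label per node (assumed field)

    def validate(self) -> None:
        """Raise ValueError if the fields are mutually inconsistent."""
        sources, targets = self.edge_index
        num_nodes = len(self.x)
        if not (len(sources) == len(targets) == len(self.edge_attr)):
            raise ValueError("every edge needs a source, a target and a weight")
        if len(self.y) != num_nodes:
            raise ValueError("expected exactly one label per node")
        for node in sources + targets:
            if not 0 <= node < num_nodes:
                raise ValueError(f"edge refers to missing node {node}")


graph = GraphSketch(
    x=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    edge_index=([0, 1], [1, 0]),
    edge_attr=[0.9, 0.9],
    y=[0.0, 1.0, 1.0],
)
graph.validate()  # consistent, so no exception is raised
```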