Fusion Model Template: Graph-based Fusion
This template is for creating your own fusion model that is graph-based. An example of a graph-based fusion model is EdgeCorrGNN.
Note
I recommend looking at How to create your own fusion model: a general template before looking at this template, as I will skip over some of the details that are covered in that template (particularly regarding documentation and idiosyncrasies of the fusion model template).
Building a graph-based fusion model differs slightly from the general template in How to create your own fusion model: a general template. The main difference is that you need to create a class that builds the graph structure from the input data.
For the EdgeCorrGNN, this is done in the EdgeCorrGraphMaker class, which is in the same .py file as the EdgeCorrGNN class.
First, let’s look at creating the graph-maker class.
Creating the Graph-Maker Class
The graph will typically be created with PyTorch Geometric, a library for building graph-based models in PyTorch.
Let’s import the libraries that we need:
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data
from fusilli.fusionmodels.base_model import ParentFusionModel
from fusilli.utils import check_model_validity
Now let’s create the graph-maker class.
The graph-maker class must have the following methods:
__init__: This method initialises the graph-maker class. It must take a torch.utils.data.Dataset as an argument (created in TrainTestGraphDataModule.setup() or KFoldGraphDataModule.setup()).
check_params: This method checks the parameters of the graph-maker class. It should raise a ValueError if the parameters are invalid. This will check validity of any modifications made to the model as well.
make_graph: This method creates the graph data structure. It must return a torch_geometric.data.Data object.
class TemplateGraphMaker:
    def __init__(self, dataset):
        self.dataset = dataset

        # other attributes for the graph maker go here

    def check_params(self):
        # check the parameters of the graph maker here
        pass

    def make_graph(self):
        # create the graph here with self.dataset
        self.check_params()

        modality_1_data = self.dataset[:][0]
        modality_2_data = self.dataset[:][1]
        labels = self.dataset[:][2]

        # some code to create the graph to get out:
        # - node attributes
        # - edge attributes
        # - edge indices

        # replace the strings with the actual graph data
        data = Data(
            x="node attributes",
            edge_attr="edge attributes",
            edge_index="edge indices",
            y="labels"
        )

        return data
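As one possible way to fill in the placeholders above, the sketch below builds the three graph ingredients with NumPy. It assumes that each sample is a node, that node attributes come from the second modality, and that an edge joins two samples whose first-modality rows have a Pearson correlation above a threshold. The class name ThresholdGraphSketch, the method make_graph_arrays, and the threshold parameter with its default are hypothetical and are not part of fusilli; this is not the EdgeCorrGraphMaker implementation. The arrays would still need converting to tensors before being passed to Data.

```python
import numpy as np


class ThresholdGraphSketch:
    """Hypothetical helper: correlation-thresholded graph from two tabular arrays."""

    def __init__(self, modality_1, modality_2, threshold=0.8):
        self.modality_1 = np.asarray(modality_1, dtype=float)
        self.modality_2 = np.asarray(modality_2, dtype=float)
        self.threshold = threshold  # assumed parameter, not a fusilli attribute

    def check_params(self):
        # Mirror the template: raise ValueError on invalid parameters.
        if not 0.0 <= self.threshold <= 1.0:
            raise ValueError("threshold must be between 0 and 1")
        if len(self.modality_1) != len(self.modality_2):
            raise ValueError("both modalities need the same number of samples")

    def make_graph_arrays(self):
        self.check_params()
        # Pearson correlation between every pair of samples (rows).
        corr = np.corrcoef(self.modality_1)
        np.fill_diagonal(corr, 0.0)  # no self-loops
        src, dst = np.nonzero(corr > self.threshold)
        edge_index = np.stack([src, dst])  # shape (2, num_edges)
        edge_attr = corr[src, dst]  # shape (num_edges,)
        x = self.modality_2  # node attributes, shape (num_nodes, num_features)
        return x, edge_index, edge_attr


# Three samples: the first two are strongly correlated, the third is anti-correlated.
modality_1 = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 4.0, 6.0, 8.1],
    [4.0, 3.0, 2.0, 1.0],
])
modality_2 = np.array([[0.1], [0.2], [0.3]])

maker = ThresholdGraphSketch(modality_1, modality_2)
x, edge_index, edge_attr = maker.make_graph_arrays()
print(edge_index)
```

Only samples 0 and 1 are connected here, once in each direction, because the graph is treated as undirected. Storing the correlation as the edge attribute lets the model weight messages by how similar two samples are.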
Creating the Fusion Model Class
Now let’s create the fusion model class that will take in the graph data structure and perform the prediction.
In addition to the class-level attributes for every fusion model, a graph-based fusion model class must have a class-level attribute graph_maker that is the graph-maker class that we created above.
Very similar to the general fusion model template in How to create your own fusion model: a general template, the fusion model class must have the following methods:
__init__: initialises the model with the input parameters prediction_task, data_dims, and multiclass_dimensions.
calc_fused_layers: checks the parameters of the fusion model if they have been modified and recalculates the layers of the fusion model where necessary.
forward: the forward pass of the fusion model. It takes x as input, which in this example is a tuple of the node features, edge indices, and edge attributes.
Note
The graph-maker class returns a torch_geometric.data.Data object, but in prepare_fusion_data(), this is converted to a torch_geometric.data.lightning.LightningNodeData object, which lets you use the torch_geometric library with PyTorch Lightning.
from torch_geometric.nn import GCNConv
class TemplateGraphFusionModel(ParentFusionModel, nn.Module):
    method_name = "Template Graph Fusion Model"
    modality_type = "tabular_tabular"
    fusion_type = "graph"
    graph_maker = TemplateGraphMaker

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        ParentFusionModel.__init__(self, prediction_task, data_dims, multiclass_dimensions)

        self.prediction_task = prediction_task

        # create some graph convolutional layers here. For example, GCNConv from PyTorch Geometric
        self.graph_conv_layers = nn.Sequential(
            GCNConv(1, 64),
            GCNConv(64, 128),
            GCNConv(128, 256),
        )

        self.calc_fused_layers()

    def calc_fused_layers(self):
        # checks on the parameters of the fusion model go here

        # calculate the final prediction layer here and the input dimension for it
        self.fused_dim = 256  # for example

        self.set_final_pred_layers(self.fused_dim)

    def forward(self, x):
        # x is a tuple of the node features, edge indices, and edge attributes
        x_n, edge_index, edge_attr = x

        # iterate over the layers defined in __init__ (the attribute name must match)
        for layer in self.graph_conv_layers:
            x_n = layer(x_n, edge_index, edge_attr)
            x_n = x_n.relu()

        out = self.final_prediction(x_n)

        # must return a list of outputs
        return [
            out,
        ]
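To make the shapes in the forward pass concrete, here is a rough NumPy sketch of what a stack of three graph convolutions computes. It assumes the standard GCN propagation rule (self-loops added, symmetric degree normalisation, then a linear map) with random weights and non-negative edge weights; the function gcn_layer and the toy graph are hypothetical, and this is neither fusilli's nor PyTorch Geometric's implementation.

```python
import numpy as np


def gcn_layer(x, edge_index, edge_weight, weight):
    """One graph convolution: normalised neighbourhood sum, then a linear map."""
    num_nodes = x.shape[0]
    # Dense weighted adjacency matrix built from the edge list.
    adj = np.zeros((num_nodes, num_nodes))
    adj[edge_index[1], edge_index[0]] = edge_weight  # messages flow source -> target
    adj += np.eye(num_nodes)  # self-loops so each node keeps its own features
    deg = adj.sum(axis=1)  # at least 1 because of the self-loops
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    adj_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return adj_norm @ x @ weight


rng = np.random.default_rng(0)
num_nodes = 5
x = rng.normal(size=(num_nodes, 1))  # one feature per node, as in GCNConv(1, 64)
edge_index = np.array([[0, 1, 1, 2, 3, 4],
                       [1, 0, 2, 1, 4, 3]])
edge_weight = np.array([0.9, 0.9, 0.5, 0.5, 0.7, 0.7])

# Same layer widths as the template: 1 -> 64 -> 128 -> 256.
dims = [1, 64, 128, 256]
for d_in, d_out in zip(dims[:-1], dims[1:]):
    weight = rng.normal(scale=0.1, size=(d_in, d_out))
    x = np.maximum(gcn_layer(x, edge_index, edge_weight, weight), 0.0)  # ReLU

print(x.shape)  # (5, 256)
```

Each node ends up with a 256-dimensional representation, which is why fused_dim is set to 256 in the template: the final prediction layer receives one vector of that size per node.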