Fusion Model Template: Graph-based Fusion

This template shows how to create your own graph-based fusion model. An example of a graph-based fusion model is EdgeCorrGNN.

Note

I recommend looking at How to create your own fusion model: a general template before looking at this template, as I will skip over some of the details that are covered in that template (particularly regarding documentation and idiosyncrasies of the fusion model template).

Building a graph-based fusion model is slightly different from following the general template in How to create your own fusion model: a general template. The main difference is that you need to write a method that creates the graph structure from the input data.

For the EdgeCorrGNN, this is done in the EdgeCorrGraphMaker class, which is in the same .py file as the EdgeCorrGNN class.

First, let’s look at creating the graph-maker class.

Creating the Graph-Maker Class

The graph will most likely be built with PyTorch Geometric, a library for creating graph-based models in PyTorch.

Let’s import the libraries that we need:

# sphinx_gallery_thumbnail_path = '_static/EdgeCorrGNN.png'

import numpy as np
import torch
import torch.nn as nn

from torch_geometric.data import Data

from fusilli.fusionmodels.base_model import ParentFusionModel
from fusilli.utils import check_model_validity

Now let’s create the graph-maker class.

The graph-maker class must have the following methods:

  • __init__: This method initialises the graph-maker class. It must take a torch.utils.data.Dataset as an argument (created in TrainTestGraphDataModule.setup() or KFoldGraphDataModule.setup()).

  • check_params: This method checks the parameters of the graph-maker class. It should raise a ValueError if the parameters are invalid. It also validates any modifications made to the model.

  • make_graph: This method creates the graph data structure. It must return a torch_geometric.data.Data object.

class TemplateGraphMaker:
    def __init__(self, dataset):
        self.dataset = dataset

        # other attributes for the graph maker go here

    def check_params(self):
        # check the parameters of the graph maker here

        pass

    def make_graph(self):
        # create the graph here with self.dataset

        self.check_params()

        modality_1_data = self.dataset[:][0]
        modality_2_data = self.dataset[:][1]
        labels = self.dataset[:][2]

        # some code to create the graph to get out:
        # - node attributes
        # - edge attributes
        # - edge indices

        # replace the strings with the actual graph data

        data = Data(
            x="node attributes",
            edge_attr="edge attributes",
            edge_index="edge indices",
            y="labels"
        )

        return data
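To make the placeholders above more concrete, here is a minimal sketch of one way to derive edge indices and edge attributes from the first modality. The function name correlation_graph, the threshold rule, and the toy TensorDataset are all assumptions for illustration; this is not necessarily how EdgeCorrGraphMaker builds its graph. It uses plain PyTorch only, so the results would still need to be passed to Data(x=..., edge_index=..., edge_attr=..., y=...) inside make_graph.

```python
import torch
from torch.utils.data import TensorDataset


def correlation_graph(features, threshold=0.8):
    """Connect samples whose feature rows have correlation >= threshold.

    Hypothetical helper: returns edge_index of shape (2, E) and edge_attr of shape (E,).
    """
    if not -1.0 <= threshold <= 1.0:
        raise ValueError(f"threshold must be in [-1, 1], got {threshold}")

    # torch.corrcoef treats each row as a variable, so with samples as rows
    # the result is the sample-by-sample correlation matrix.
    corr = torch.corrcoef(features)

    # Keep strongly correlated pairs and drop self-loops on the diagonal.
    mask = corr >= threshold
    mask &= ~torch.eye(corr.size(0), dtype=torch.bool)

    edge_index = mask.nonzero().t().contiguous()  # (2, E), dtype long
    edge_attr = corr[mask]  # (E,), same row-major order as nonzero()
    return edge_index, edge_attr


# Toy stand-in for the dataset handed to the graph maker (an assumption).
modality_1 = torch.tensor(
    [[1.0, 2.0, 3.0], [2.0, 4.0, 6.1], [3.0, 1.0, 0.5], [0.9, 2.1, 2.9]]
)
modality_2 = torch.tensor([[0.5], [0.7], [0.1], [0.6]])
labels = torch.tensor([1, 1, 0, 1])
dataset = TensorDataset(modality_1, modality_2, labels)

# Slicing a TensorDataset returns one tensor per stored field,
# which is why the template indexes with dataset[:][0], dataset[:][1], ...
modality_1_data = dataset[:][0]
node_attributes = dataset[:][1]

edge_index, edge_attr = correlation_graph(modality_1_data, threshold=0.8)
```

In this toy data, samples 0, 1 and 3 are strongly correlated with each other while sample 2 is not, so sample 2 ends up as an isolated node. Whether isolated nodes are acceptable is a design choice that check_params could enforce.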

Creating the Fusion Model Class

Now let’s create the fusion model class that will take in the graph data structure and perform the prediction.

In addition to the class-level attributes for every fusion model, a graph-based fusion model class must have a class-level attribute graph_maker that is the graph-maker class that we created above.
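As a rough illustration of why graph_maker is stored as the class itself rather than an instance, the sketch below shows how calling code could look it up and instantiate it with a dataset. ToyGraphMaker, ToyFusionModel and build_graph are hypothetical names, and the returned dictionary is a stand-in for a real Data object; the actual Fusilli data modules may do this differently.

```python
class ToyGraphMaker:
    """Hypothetical graph maker that only records how many nodes it would create."""

    def __init__(self, dataset):
        self.dataset = dataset

    def make_graph(self):
        # A real graph maker would return a torch_geometric.data.Data object here.
        return {"num_nodes": len(self.dataset)}


class ToyFusionModel:
    # Class-level attribute: the graph-maker class, not an instance of it.
    graph_maker = ToyGraphMaker


def build_graph(model_class, dataset):
    # The caller only needs the model class to find the matching graph maker.
    maker = model_class.graph_maker(dataset)
    return maker.make_graph()


graph = build_graph(ToyFusionModel, dataset=[(0.1, 0.2, 1), (0.3, 0.4, 0)])
```

Because the attribute lives on the class, the graph can be built before any fusion model instance exists.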

As in the general fusion model template in How to create your own fusion model: a general template, the fusion model class must have the following methods:

  • __init__: initialising with input parameters prediction_task, data_dims, and multiclass_dimensions.

  • calc_fused_layers: checks the parameters of the fusion model if they have been modified and recalculates the layers of the fusion model where necessary.

  • forward: the forward pass of the fusion model. Takes x as input; in this example, x is a tuple of the node features, edge indices, and edge attributes.

Note

The graph-maker class returns a torch_geometric.data.Data object, but in prepare_fusion_data(), this is converted to a torch_geometric.data.lightning.LightningNodeData object, which lets you use the torch_geometric library with PyTorch Lightning.

from torch_geometric.nn import GCNConv


class TemplateGraphFusionModel(ParentFusionModel, nn.Module):
    method_name = "Template Graph Fusion Model"
    modality_type = "tabular_tabular"
    fusion_type = "graph"

    graph_maker = TemplateGraphMaker

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        ParentFusionModel.__init__(self, prediction_task, data_dims, multiclass_dimensions)

        self.prediction_task = prediction_task

        # create some graph convolutional layers here. For example, GCNConv from PyTorch Geometric
        # nn.ModuleList because each layer is called with extra graph arguments in forward()
        self.graph_layers = nn.ModuleList(
            [
                GCNConv(1, 64),
                GCNConv(64, 128),
                GCNConv(128, 256),
            ]
        )

        self.calc_fused_layers()

    def calc_fused_layers(self):
        # checks on the parameters of the fusion model go here

        # calculate the final prediction layer here and the input dimension for it

        self.fused_dim = 256  # for example

        self.set_final_pred_layers(self.fused_dim)

    def forward(self, x):
        # x is a tuple of the node features, edge indices, and edge attributes
        x_n, edge_index, edge_attr = x

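        # iterate over the layers defined in __init__ (self.graph_layers)
        for layer in self.graph_layers:
            # GCNConv uses the third argument as an optional per-edge weight
            x_n = layer(x_n, edge_index, edge_attr)
            x_n = x_n.relu()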

        out = self.final_prediction(x_n)

        # must return a list of outputs

        return [
            out,
        ]
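To check the shapes flowing through a forward pass like the one above without installing PyTorch Geometric, here is a small sketch in plain PyTorch. WeightedSumConv is a hypothetical stand-in that is much simpler than GCNConv (no normalisation and no self-loops), and final_prediction here is just a linear layer standing in for whatever set_final_pred_layers creates. Only the input tuple and the list-shaped output mirror the template.

```python
import torch
import torch.nn as nn


class WeightedSumConv(nn.Module):
    """Hypothetical stand-in for a graph convolution (not GCNConv):
    a linear map followed by a weighted sum over incoming edges."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x_n, edge_index, edge_attr):
        source, target = edge_index
        # One message per edge: transformed source features scaled by the edge weight.
        messages = self.linear(x_n)[source] * edge_attr.unsqueeze(-1)
        out = torch.zeros(x_n.size(0), messages.size(1), dtype=messages.dtype)
        # Add each message to the row of the node it points at.
        return out.index_add(0, target, messages)


torch.manual_seed(0)
x_n = torch.randn(4, 1)  # 4 nodes, 1 feature each
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # 4 directed edges
edge_attr = torch.tensor([0.9, 0.9, 0.5, 0.5])  # one weight per edge

layers = nn.ModuleList([WeightedSumConv(1, 64), WeightedSumConv(64, 256)])
final_prediction = nn.Linear(256, 1)  # stand-in for the final prediction layers

x = (x_n, edge_index, edge_attr)  # the tuple that forward() receives
h, ei, ea = x
for layer in layers:
    h = layer(h, ei, ea).relu()

out = [final_prediction(h)]  # a list holding one prediction per node
```

Node 3 has no incoming edges, so its hidden representation stays at zero in this simplified layer. The prediction is made per node, which is why the output has one row for each node in the graph.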
