.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "contributing_examples/template_graph_fusion.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_contributing_examples_template_graph_fusion.py:


Fusion Model Template: Graph-based Fusion
=========================================

This template is for creating your own graph-based fusion model. An example of a graph-based fusion model is :class:`~fusilli.fusionmodels.tabularfusion.edge_corr_gnn.EdgeCorrGNN`.

.. note::

    I recommend looking at :ref:`how_to_contribute_a_template_other_fusion` before this template, as I will skip over some of the details covered there (particularly regarding documentation and the idiosyncrasies of the fusion model template).

Building a graph-based fusion model differs from the general template in :ref:`how_to_contribute_a_template_other_fusion` in one main way: you also need to write a class that builds the graph structure from the input data.
For :class:`~fusilli.fusionmodels.tabularfusion.edge_corr_gnn.EdgeCorrGNN`, this is done in the :class:`~fusilli.fusionmodels.tabularfusion.edge_corr_gnn.EdgeCorrGraphMaker` class, which is in the same ``.py`` file as the :class:`~fusilli.fusionmodels.tabularfusion.edge_corr_gnn.EdgeCorrGNN` class.

First, let's look at creating the graph-maker class.

Creating the Graph-Maker Class
------------------------------

The graph will usually be created with PyTorch Geometric (``torch_geometric``), a library for building graph neural networks in PyTorch.

Let's import the libraries that we need:

.. code-block:: Python

    # sphinx_gallery_thumbnail_path = '_static/EdgeCorrGNN.png'
    import numpy as np
    import torch
    import torch.nn as nn
    from torch_geometric.data import Data

    from fusilli.fusionmodels.base_model import ParentFusionModel
    from fusilli.utils import check_model_validity


Now let's create the graph-maker class. The graph-maker class must have the following methods:

- ``__init__``: initialises the graph maker. It must take a ``torch.utils.data.Dataset`` as an argument (created in :meth:`.TrainTestGraphDataModule.setup` or :meth:`.KFoldGraphDataModule.setup`).
- ``check_params``: checks the parameters of the graph maker and raises a ``ValueError`` if any of them are invalid. This also validates any modifications made to those parameters.
- ``make_graph``: creates the graph data structure. It must return a ``torch_geometric.data.Data`` object.

.. code-block:: Python

    class TemplateGraphMaker:
        def __init__(self, dataset):
            self.dataset = dataset
            # other attributes for the graph maker go here

        def check_params(self):
            # check the parameters of the graph maker here and
            # raise a ValueError if any of them are invalid
            pass

        def make_graph(self):
            # create the graph here from self.dataset
            self.check_params()

            # slicing the whole dataset returns (modality 1, modality 2, labels)
            modality_1_data, modality_2_data, labels = self.dataset[:]

            # some code to create the graph to get out:
            # - node attributes
            # - edge attributes
            # - edge indices

            # replace the placeholder strings with the actual graph data
            data = Data(
                x="node attributes",
                edge_attr="edge attributes",
                edge_index="edge indices",
                y=labels,
            )

            return data


Creating the Fusion Model Class
-------------------------------

Now let's create the fusion model class that will take in the graph data structure and perform the prediction.
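
Before writing the class, it may help to see that data flow in isolation. The standalone sketch below uses plain PyTorch Geometric (no fusilli) on a toy 4-node graph: it unpacks a ``Data`` object into a tuple of node features, edge indices and edge attributes, then passes the node features through two ``GCNConv`` layers, using the 1-D edge attributes as edge weights. The toy graph, the layer sizes, and the names ``conv_layers`` and ``head`` (an ``nn.Linear`` standing in for fusilli's final prediction layer) are assumptions for illustration only.

.. code-block:: Python

    import torch
    import torch.nn as nn
    from torch_geometric.data import Data
    from torch_geometric.nn import GCNConv

    # Toy graph: 4 nodes (e.g. 4 subjects), each with 3 node features.
    x = torch.randn(4, 3)
    # Undirected edges 0-1, 1-2, 2-3, stored in both directions as a [2, num_edges] tensor.
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 2, 3],
         [1, 0, 2, 1, 3, 2]],
        dtype=torch.long,
    )
    # One scalar per edge; GCNConv uses it as its edge weight.
    edge_attr = torch.tensor([0.9, 0.9, 0.5, 0.5, 0.7, 0.7])
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    # The (node features, edge indices, edge attributes) tuple layout.
    x_n, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr

    # Hypothetical layers: two graph convolutions and a linear prediction head.
    conv_layers = nn.ModuleList([GCNConv(3, 16), GCNConv(16, 32)])
    head = nn.Linear(32, 1)

    for layer in conv_layers:
        # GCNConv's forward takes (x, edge_index, edge_weight)
        x_n = layer(x_n, edge_index, edge_attr).relu()

    out = head(x_n)  # one prediction per node
    print(out.shape)

Each node (subject) gets its own prediction, which is why graph-based fusion models in this style predict per node rather than per graph.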
In addition to the class-level attributes that every fusion model has, a graph-based fusion model class **must** have a class-level attribute ``graph_maker``, set to the graph-maker class that we created above.

Very similar to the general fusion model template in :ref:`how_to_contribute_a_template_other_fusion`, the fusion model class must have the following methods:

- ``__init__``: initialises the model with the input parameters ``prediction_task``, ``data_dims``, and ``multiclass_dimensions``.
- ``calc_fused_layers``: checks the parameters of the fusion model (in case they have been modified) and recalculates the layers of the fusion model where necessary.
- ``forward``: the forward pass of the fusion model. It takes ``x`` as input; in this example, ``x`` is a tuple of the node features, edge indices, and edge attributes.

.. note::

    The graph-maker class returns a ``torch_geometric.data.Data`` object, but :func:`~.prepare_fusion_data` converts it to a ``torch_geometric.data.lightning.LightningNodeData`` object, which lets you use ``torch_geometric`` with PyTorch Lightning.

.. code-block:: Python

    from torch_geometric.nn import GCNConv


    class TemplateGraphFusionModel(ParentFusionModel, nn.Module):
        method_name = "Template Graph Fusion Model"
        modality_type = "tabular_tabular"
        fusion_type = "graph"
        graph_maker = TemplateGraphMaker

        def __init__(self, prediction_task, data_dims, multiclass_dimensions):
            ParentFusionModel.__init__(
                self, prediction_task, data_dims, multiclass_dimensions
            )

            self.prediction_task = prediction_task

            # create some graph convolutional layers here, for example GCNConv
            # from PyTorch Geometric. The input size of the first layer (1 here)
            # must match the number of node features produced by the graph maker.
            self.graph_conv_layers = nn.Sequential(
                GCNConv(1, 64),
                GCNConv(64, 128),
                GCNConv(128, 256),
            )

            self.calc_fused_layers()

        def calc_fused_layers(self):
            # checks on the parameters of the fusion model go here

            # the input dimension of the final prediction layer is the
            # output size of the last graph convolutional layer
            self.fused_dim = self.graph_conv_layers[-1].out_channels

            self.set_final_pred_layers(self.fused_dim)

        def forward(self, x):
            # x is a tuple of the node features, edge indices, and edge attributes
            x_n, edge_index, edge_attr = x

            # call each layer explicitly, because GCNConv needs the edge
            # indices and weights as well as the node features; it treats
            # the third argument as a 1-D edge weight
            for layer in self.graph_conv_layers:
                x_n = layer(x_n, edge_index, edge_attr)
                x_n = x_n.relu()

            out = self.final_prediction(x_n)

            # must return a list of outputs
            return [
                out,
            ]
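
For reference, here is one way the placeholders in ``TemplateGraphMaker.make_graph`` could be filled in. This is a hypothetical sketch loosely based on the idea behind :class:`~fusilli.fusionmodels.tabularfusion.edge_corr_gnn.EdgeCorrGNN` (one node per subject, with edges between subjects whose first-modality features are strongly correlated), not fusilli's actual implementation. The class name ``CorrelationGraphMaker``, the ``threshold`` attribute and its default of 0.8 are assumptions, and a ``TensorDataset`` stands in for the dataset that fusilli's data modules pass in.

.. code-block:: Python

    import numpy as np
    import torch
    from torch.utils.data import TensorDataset
    from torch_geometric.data import Data


    class CorrelationGraphMaker:
        """Hypothetical graph maker: one node per subject, with edges between
        subjects whose modality-1 feature vectors are strongly correlated."""

        def __init__(self, dataset):
            self.dataset = dataset
            self.threshold = 0.8  # assumed default: keep edges with |corr| > 0.8

        def check_params(self):
            if not 0 < self.threshold < 1:
                raise ValueError(
                    f"threshold must be between 0 and 1, got {self.threshold}"
                )

        def make_graph(self):
            self.check_params()
            modality_1_data, modality_2_data, labels = self.dataset[:]

            # subject-by-subject Pearson correlation of the modality-1 features
            corr = np.corrcoef(modality_1_data.numpy())
            np.fill_diagonal(corr, 0.0)  # drop self-correlations (no self-edges)
            src, dst = np.nonzero(np.abs(corr) > self.threshold)

            return Data(
                x=modality_2_data.float(),  # node features come from modality 2
                edge_index=torch.tensor(np.stack([src, dst]), dtype=torch.long),
                edge_attr=torch.tensor(corr[src, dst], dtype=torch.float),
                y=labels,
            )


    # usage on random data with 6 subjects
    torch.manual_seed(0)
    mod1 = torch.randn(6, 10)
    mod1[1] = 2 * mod1[0] + 0.01  # make subjects 0 and 1 perfectly correlated
    mod2 = torch.randn(6, 4)
    labels = torch.randint(0, 2, (6,))

    graph = CorrelationGraphMaker(TensorDataset(mod1, mod2, labels)).make_graph()
    print(graph)

Because the correlation matrix is symmetric, every edge is stored in both directions, and ``edge_attr`` ends up as a 1-D tensor of weights, which is the form ``GCNConv`` accepts as its edge-weight argument.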