Modify the Fusion Models

The fusion models in fusilli can be customized by passing a dictionary of attributes into the fusilli.data.prepare_fusion_data() and fusilli.train.train_and_save_models() functions.

Examples of how to modify the models can be found in the Comparing Models section.

Below are the modifiable attributes with guidance on how they can be changed:

Note

If layer groups such as mod1_layers or img_layers are modified, fused_layers is updated automatically so that its first layer has the correct number of input features for the modified layers. Similarly, altering fused_layers adjusts the input features of the final_prediction layer accordingly.
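Below is a minimal PyTorch sketch of the kind of adjustment this note describes. It rests on two assumptions: the outputs of the modified layer groups are concatenated before fused_layers, and fused_layers starts with an nn.Linear. The helpers last_out_features and resize_fused_input are hypothetical names chosen for illustration. They are not fusilli functions, and fusilli's own implementation may differ.

```python
import torch.nn as nn


def last_out_features(layers: nn.Module) -> int:
    """Hypothetical helper: out_features of the last nn.Linear in a layer group."""
    out_features = None
    # modules() yields submodules depth-first in registration order; we assume
    # that order matches the order the layers run in the forward pass
    for module in layers.modules():
        if isinstance(module, nn.Linear):
            out_features = module.out_features
    if out_features is None:
        raise ValueError("layer group contains no nn.Linear")
    return out_features


def resize_fused_input(fused_layers: nn.Sequential, *branches: nn.Module) -> nn.Sequential:
    """Hypothetical helper: rebuild the first nn.Linear of fused_layers so its
    in_features equal the summed output sizes of the (assumed concatenated) branches."""
    in_features = sum(last_out_features(branch) for branch in branches)
    first = fused_layers[0]
    if not isinstance(first, nn.Linear):
        raise TypeError("this sketch assumes fused_layers starts with nn.Linear")
    if first.in_features != in_features:
        # Replace in place; nn.Sequential supports item assignment by integer index
        fused_layers[0] = nn.Linear(in_features, first.out_features, bias=first.bias is not None)
    return fused_layers


mod1_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Linear(10, 32), nn.ReLU()),
    "layer 2": nn.Sequential(nn.Linear(32, 64), nn.ReLU()),
})
mod2_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Linear(15, 16), nn.ReLU()),
})
fused_layers = nn.Sequential(nn.Linear(999, 50), nn.ReLU())  # deliberately wrong input size

resize_fused_input(fused_layers, mod1_layers, mod2_layers)
print(fused_layers[0])  # Linear(in_features=80, out_features=50, bias=True)
```

The 80 input features come from concatenating the 64 outputs of mod1_layers with the 16 outputs of mod2_layers.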

Warning

Errors occur if the input features of a layer group (e.g. mod1_layers or img_layers) do not match the data. For instance, setting the input features of mod1_layers to 20 when the first tabular modality actually has 10 features results in a matrix multiplication error during the forward method.
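The error described in this warning can be reproduced in plain PyTorch. The tensor sizes below are illustrative. The exact wording of the error message depends on the PyTorch version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# First tabular modality: 8 subjects with 10 features each (illustrative sizes)
tab1 = torch.randn(8, 10)

# mod1_layers declared with 20 input features, which does not match the data
bad_layers = nn.ModuleDict({"layer 1": nn.Sequential(nn.Linear(20, 32), nn.ReLU())})

try:
    bad_layers["layer 1"](tab1)
    mismatch_raised = False
except RuntimeError as err:
    # Typically reports that the (8x10) and (20x32) matrices cannot be multiplied
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Declaring the first layer with the real feature count fixes the forward pass
good_layers = nn.ModuleDict({"layer 1": nn.Sequential(nn.Linear(10, 32), nn.ReLU())})
out = good_layers["layer 1"](tab1)
print(out.shape)  # torch.Size([8, 32])
```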

Warning

If you’re using external test data, don’t forget to pass the layer modifications into your evaluation figure function (like fusilli.eval.RealsVsPreds.from_new_data()).

Constructing the Layer Modification Dictionary

To construct the dictionary:

  • The first-level keys are the names of the fusion models (or subspace methods) listed below, or “all”.

  • The second-level keys are the attributes listed in the table for that model.

  • Each value is the intended modification for that attribute.
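The sketch below illustrates this two-level structure by checking a modification dictionary against an excerpt of the attribute tables in this section. ALLOWED_ATTRIBUTES and check_layer_mods are hypothetical names. fusilli may validate keys differently, or not at all. The None values are placeholders for real modifications.

```python
# Hypothetical excerpt of the attribute tables in this section;
# extend it with the rows for the models you actually use.
ALLOWED_ATTRIBUTES = {
    "ActivationFusion": {"mod1_layers", "mod2_layers", "fused_layers"},
    "ConcatTabularData": {"fused_layers"},
    "EdgeCorrGraphMaker": {"threshold"},
}


def check_layer_mods(layer_mods: dict, allowed: dict) -> list:
    """Return a list of problems found in a two-level modification dictionary."""
    problems = []
    # Under "all", accept any attribute that at least one listed model has
    all_attributes = set().union(*allowed.values())
    for model_name, attributes in layer_mods.items():
        if model_name == "all":
            valid = all_attributes
        elif model_name in allowed:
            valid = allowed[model_name]
        else:
            problems.append(f"unknown model key: {model_name!r}")
            continue
        if not isinstance(attributes, dict):
            problems.append(f"{model_name!r} should map to a dict of attributes")
            continue
        for attribute in attributes:
            if attribute not in valid:
                problems.append(f"{model_name!r} has no modifiable attribute {attribute!r}")
    return problems


mods = {
    "all": {"fused_layers": None},           # placeholder for an nn.Sequential
    "EdgeCorrGraphMaker": {"threshold": 0.8},
    "ConcatTabularData": {"mod1_layers": None},  # not in this model's table
}
problems = check_layer_mods(mods, ALLOWED_ATTRIBUTES)
print(problems)  # ["'ConcatTabularData' has no modifiable attribute 'mod1_layers'"]
```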

Modifications listed under the “all” key are applied to all fusion models unless a specific fusion model's key overrides them. Here’s an example demonstrating this:

import torch.nn as nn

layer_modifications = {
    "all": {
        "mod1_layers": nn.ModuleDict(
            {
                "layer 1": nn.Sequential(
                    nn.Linear(20, 32),
                    nn.ReLU(),
                ),
                # ... (additional layer modifications)
            }
        ),
    },  # end of "all" key
    "ConcatImgMapsTabularMaps": {  # overrides modifications made to "all"
        "mod1_layers": nn.ModuleDict(
            {
                "layer 1": nn.Sequential(
                    nn.Linear(20, 100),
                    nn.ReLU(),
                ),
                # ... (additional layer modifications)
            }
        ),
    },
    # ... (additional fusion model modifications)
}
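One way to read the precedence rule is as an attribute-by-attribute merge. The sketch below assumes that interpretation. It also assumes that attributes under “all” are applied only to models that have them. resolve_modifications is a hypothetical helper, not part of fusilli, and the strings stand in for the nn.ModuleDict values.

```python
def resolve_modifications(layer_mods: dict, model_name: str, model_attributes: set) -> dict:
    """Hypothetical helper: merge "all" with one model's entry; model-specific values win."""
    resolved = {}
    # Start from the "all" entry, keeping only attributes this model actually has
    for attribute, value in layer_mods.get("all", {}).items():
        if attribute in model_attributes:
            resolved[attribute] = value
    # Model-specific modifications override anything set under "all"
    resolved.update(layer_mods.get(model_name, {}))
    return resolved


layer_mods = {
    "all": {"mod1_layers": "all-mod1", "fused_layers": "all-fused"},
    "ConcatImgMapsTabularMaps": {"mod1_layers": "specific-mod1"},
}
resolved_maps = resolve_modifications(
    layer_mods, "ConcatImgMapsTabularMaps", {"mod1_layers", "img_layers", "fused_layers"}
)
resolved_concat = resolve_modifications(layer_mods, "ConcatTabularData", {"fused_layers"})
print(resolved_maps)    # {'mod1_layers': 'specific-mod1', 'fused_layers': 'all-fused'}
print(resolved_concat)  # {'fused_layers': 'all-fused'}
```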

Modifiable Attributes

ActivationFusion

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict |
| mod2_layers | nn.ModuleDict |
| fused_layers | nn.Sequential |

AttentionAndActivation

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict |
| mod2_layers | nn.ModuleDict |
| fused_layers | nn.Sequential |
| attention_reduction_ratio | int |

AttentionWeightedGNN

| Attribute | Guidance |
| --- | --- |
| graph_conv_layers | nn.Sequential of torch_geometric.nn layers. |
| dropout_prob | Float between 0 and 1, exclusive. |

AttentionWeightedGraphMaker

| Attribute | Guidance |
| --- | --- |
| early_stop_callback | EarlyStopping object, imported with from lightning.pytorch.callbacks import EarlyStopping |
| edge_probability_threshold | Integer between 0 and 100. |
| attention_MLP_test_size | Float between 0 and 1. |
| weighting_layers | nn.ModuleDict. The output size of the final layer must equal the input size of the first layer. |
| fused_layers | nn.Sequential |
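The weighting_layers constraint can be checked before training with a few lines of PyTorch. This sketch assumes the relevant sizes are those of the first and last nn.Linear layers in the dictionary, and that those layers are registered in execution order. The helper weighting_layers_ok and the feature count of 12 are hypothetical.

```python
import torch.nn as nn


def weighting_layers_ok(weighting_layers: nn.ModuleDict) -> bool:
    """Hypothetical check: last layer's output size equals first layer's input size."""
    # Collect nn.Linear modules in registration order (assumed to be execution order)
    linears = [m for m in weighting_layers.modules() if isinstance(m, nn.Linear)]
    return bool(linears) and linears[0].in_features == linears[-1].out_features


n_features = 12  # hypothetical input size of the weighting layers
weighting_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Linear(n_features, 64), nn.ReLU()),
    "layer 2": nn.Sequential(nn.Linear(64, 64), nn.ReLU()),
    "layer 3": nn.Sequential(nn.Linear(64, n_features), nn.ReLU()),
})
print(weighting_layers_ok(weighting_layers))  # True
```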

ConcatImgLatentTabDoubleLoss

| Attribute | Guidance |
| --- | --- |
| latent_dim | int |
| encoder | nn.Sequential |
| decoder | nn.Sequential |
| custom_loss | Loss function, e.g. nn.MSELoss |
| fused_layers | nn.Sequential |
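The table gives nn.MSELoss as an example custom_loss. The sketch below assumes custom_loss is used like a standard PyTorch loss module, that is, an instance called as custom_loss(prediction, target). The tensors are made-up values chosen so the results are easy to verify by hand.

```python
import torch
import torch.nn as nn

# Made-up reconstruction and target tensors
reconstruction = torch.tensor([1.0, 2.0])
target = torch.tensor([1.0, 4.0])

results = {}
for custom_loss in (nn.MSELoss(), nn.L1Loss()):
    # Both losses average over elements by default (reduction="mean")
    results[type(custom_loss).__name__] = custom_loss(reconstruction, target).item()
print(results)  # {'MSELoss': 2.0, 'L1Loss': 1.0}
```

MSELoss averages the squared differences (0 and 4), giving 2.0. L1Loss averages the absolute differences (0 and 2), giving 1.0.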


ConcatImgLatentTabDoubleTrain

| Attribute | Guidance |
| --- | --- |
| fused_layers | nn.Sequential |


concat_img_latent_tab_subspace_method

| Attribute | Guidance |
| --- | --- |
| autoencoder.latent_dim | int |
| autoencoder.encoder | nn.Sequential |
| autoencoder.decoder | nn.Sequential |
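Some attribute names in these tables are dotted (for example autoencoder.latent_dim). This suggests they refer to an attribute of an object held by the subspace method. The sketch below shows one way to apply such a dotted name with getattr and setattr. set_dotted_attr and the SimpleNamespace stand-in are hypothetical, and fusilli's own mechanism may differ.

```python
from functools import reduce
from types import SimpleNamespace


def set_dotted_attr(obj, dotted_name: str, value) -> None:
    """Hypothetical helper: set e.g. "autoencoder.latent_dim" by walking the attribute chain."""
    *parents, leaf = dotted_name.split(".")
    # Follow every attribute except the last (here: obj.autoencoder);
    # with no parents, reduce returns obj itself
    target = reduce(getattr, parents, obj)
    setattr(target, leaf, value)


# Hypothetical stand-in for a subspace method that owns an autoencoder
subspace_method = SimpleNamespace(autoencoder=SimpleNamespace(latent_dim=32))

layer_mods = {"concat_img_latent_tab_subspace_method": {"autoencoder.latent_dim": 64}}
for name, value in layer_mods["concat_img_latent_tab_subspace_method"].items():
    set_dotted_attr(subspace_method, name, value)

print(subspace_method.autoencoder.latent_dim)  # 64
```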


ConcatImgMapsTabularData

| Attribute | Guidance |
| --- | --- |
| img_layers | nn.ModuleDict |
| fused_layers | nn.Sequential |


ConcatImgMapsTabularMaps

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict |
| img_layers | nn.ModuleDict |
| fused_layers | nn.Sequential |


ConcatTabularData

| Attribute | Guidance |
| --- | --- |
| fused_layers | nn.Sequential |


ConcatTabularFeatureMaps

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict |
| mod2_layers | nn.ModuleDict |
| fused_layers | nn.Sequential |


CrossmodalMultiheadAttention

| Attribute | Guidance |
| --- | --- |
| attention_embed_dim | int |
| mod1_layers | nn.ModuleDict |
| img_layers | nn.ModuleDict |


DAETabImgMaps

| Attribute | Guidance |
| --- | --- |
| fusion_layers | nn.Sequential |


denoising_autoencoder_subspace_method

| Attribute | Guidance |
| --- | --- |
| autoencoder.latent_dim | int |
| autoencoder.upsampler | nn.Sequential |
| autoencoder.downsampler | nn.Sequential |
| img_unimodal.img_layers | nn.ModuleDict. Overrides modification of img_layers made to “all”. |
| img_unimodal.fused_layers | nn.Sequential |


EdgeCorrGNN

| Attribute | Guidance |
| --- | --- |
| graph_conv_layers | nn.Sequential of torch_geometric.nn.GCNConv layers. The first layer’s input features should equal the number of features in the second tabular modality; if they do not, this is corrected automatically. |
| dropout_prob | Float between 0 and 1, exclusive. |


EdgeCorrGraphMaker

| Attribute | Guidance |
| --- | --- |
| threshold | Float between 0 and 1, exclusive. |


ImageChannelWiseMultiAttention

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. Must have the same number of layers as img_layers. |
| img_layers | nn.ModuleDict. Overrides modification of img_layers made to “all”. Must have the same number of layers as mod1_layers. |
| fused_layers | nn.Sequential |
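The pairing requirement between mod1_layers and img_layers can be checked with len(), because nn.ModuleDict reports the number of entries it holds. The same check applies to the mod1_layers/mod2_layers pairs of the tabular attention models further down this page. The layer sizes below are hypothetical. Matching the tabular output sizes to the image channel counts (32 and 64 here) is a choice made for this sketch, not a documented requirement.

```python
import torch.nn as nn


def paired_layer_groups_ok(group_a: nn.ModuleDict, group_b: nn.ModuleDict) -> bool:
    """Hypothetical check: both layer groups contain the same number of layers."""
    # len() of an nn.ModuleDict is the number of entries it holds
    return len(group_a) == len(group_b)


mod1_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Linear(10, 32), nn.ReLU()),
    "layer 2": nn.Sequential(nn.Linear(32, 64), nn.ReLU()),
})
img_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU()),
    "layer 2": nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU()),
})
print(paired_layer_groups_ok(mod1_layers, img_layers))  # True
```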


ImageDecision

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. |
| img_layers | nn.ModuleDict. Overrides modification of img_layers made to “all”. |
| fusion_operation | Function (such as mean, median, etc.). Should act on the 1st dimension. |
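Below are a few candidate fusion_operation functions. The sketch assumes that “1st dimension” means dim=1 in PyTorch's zero-based indexing. It also assumes the per-modality predictions are stacked so that dimension 1 indexes the modality, giving shape (batch, n_modalities). Both are assumptions, and the prediction values are made up.

```python
import torch

# Made-up decision-level outputs: one prediction per modality for each subject
preds = torch.tensor([[0.2, 0.6],
                      [0.9, 0.5]])

candidate_operations = {
    "mean": lambda x: torch.mean(x, dim=1),
    "max": lambda x: torch.amax(x, dim=1),
    # torch.median with dim returns (values, indices); with an even number of
    # entries it returns the lower of the two middle values
    "median": lambda x: torch.median(x, dim=1).values,
}
for name, operation in candidate_operations.items():
    print(name, operation(preds))

# Passing one of them in as a modification
layer_mods = {"ImageDecision": {"fusion_operation": candidate_operations["mean"]}}
```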


ImgUnimodal

| Attribute | Guidance |
| --- | --- |
| img_layers | nn.ModuleDict. Overrides modification of img_layers made to “all”. |
| fused_layers | nn.Sequential |


MCVAE_tab

| Attribute | Guidance |
| --- | --- |
| latent_space_layers | nn.ModuleDict. The input channels of the first layer should equal the latent space size; this is also enforced in calc_fused_layers(). |
| fused_layers | nn.Sequential |


MCVAESubspaceMethod

| Attribute | Guidance |
| --- | --- |
| num_latent_dims | int |


TabularCrossmodalMultiheadAttention

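| Attribute | Guidance |
| --- | --- |
| attention_embed_dim | int |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. Must have the same number of layers as mod2_layers. |
| mod2_layers | nn.ModuleDict. Overrides modification of mod2_layers made to “all”. Must have the same number of layers as mod1_layers. |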

Tabular1Unimodal

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. |
| fused_layers | nn.Sequential |


Tabular2Unimodal

| Attribute | Guidance |
| --- | --- |
| mod2_layers | nn.ModuleDict. Overrides modification of mod2_layers made to “all”. |
| fused_layers | nn.Sequential |


TabularChannelWiseMultiAttention

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. Must have the same number of layers as mod2_layers. |
| mod2_layers | nn.ModuleDict. Overrides modification of mod2_layers made to “all”. Must have the same number of layers as mod1_layers. |
| fused_layers | nn.Sequential |


TabularDecision

| Attribute | Guidance |
| --- | --- |
| mod1_layers | nn.ModuleDict. Overrides modification of mod1_layers made to “all”. |
| mod2_layers | nn.ModuleDict. Overrides modification of mod2_layers made to “all”. |
| fusion_operation | Function (such as mean, median, etc.). Should act on the 1st dimension. |