fusilli.fusionmodels.tabularimagefusion.crossmodal_att

Crossmodal multi-head attention model. This model combines self-attention and cross-modal attention over the two modalities: tabular and image.

Classes

CrossmodalMultiheadAttention(...)

Crossmodal multi-head attention model.

class CrossmodalMultiheadAttention(prediction_task, data_dims, multiclass_dimensions)[source]

Bases: ParentFusionModel, Module

Crossmodal multi-head attention model. This model combines self-attention and cross-modal attention over the two modalities: tabular and image.

References

Golovanevsky, M., Eickhoff, C., & Singh, R. (2022). Multimodal attention-based deep learning for Alzheimer’s disease diagnosis. Journal of the American Medical Informatics Association, 29(12), 2014–2022. https://doi.org/10.1093/jamia/ocac168

MADDi training script: https://github.com/rsinghlab/MADDi/blob/main/training/train_all_modalities.py
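As a rough illustration of the two attention patterns named above, the framework-free sketch below implements single-head scaled dot-product attention in plain Python. Self-attention takes queries, keys and values from the same modality; cross-modal attention takes queries from one modality and keys and values from the other. The toy embeddings are made up, and the sketch omits the learned query/key/value projections and the multiple heads of a real nn.MultiheadAttention layer, so it shows the mechanism rather than this class's implementation.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(query: Matrix, key: Matrix, value: Matrix) -> Matrix:
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = len(query[0])
    scores = matmul(query, transpose(key))
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, value)


# Made-up toy embeddings: 2 positions x 4 features per modality.
tab = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
img = [[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]

tab_self = attention(tab, tab, tab)    # self-attention within the tabular modality
img_self = attention(img, img, img)    # self-attention within the image modality
tab_to_img = attention(tab, img, img)  # tabular queries attend to image keys/values
img_to_tab = attention(img, tab, tab)  # image queries attend to tabular keys/values
```

Each output row is a weighted average of the value rows, so tab_to_img mixes image features according to how strongly each tabular position matches each image position.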

prediction_task

Type of prediction to be performed.

Type:

str

attention_embed_dim

Embedding dimension (number of input features) of the multi-head attention layer.

Type:

int

mod1_layers

Dictionary containing the layers of the first modality (tabular data).

Type:

nn.ModuleDict

img_layers

Dictionary containing the layers of the image data.

Type:

nn.ModuleDict

fused_dim

Number of features of the fused layers. This is the flattened output size of the image layers.

Type:

int
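One way to arrive at fused_dim is to apply the standard convolution/pooling output-length formula, floor((n + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1, to every spatial axis and multiply by the number of output channels. The sketch below assumes a hypothetical image-layer stack (two convolution + max-pooling blocks ending in 64 channels, applied to a 100x100 image); it is not necessarily how get_fused_dim() computes the value.

```python
import math
from typing import List, Sequence, Tuple


def conv_out_len(n: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis for a convolution or max-pooling layer."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def flattened_size(spatial: Sequence[int], layers: Sequence[Tuple[int, int, int]], out_channels: int) -> int:
    """Apply each (kernel, stride, padding) layer to every spatial axis, then flatten."""
    dims: List[int] = list(spatial)
    for kernel, stride, padding in layers:
        dims = [conv_out_len(n, kernel, stride, padding) for n in dims]
    return out_channels * math.prod(dims)


# Hypothetical 2D stack on a 100x100 image: (conv 3x3 -> max-pool 2x2) twice,
# with the last convolution producing 64 channels.
example_layers = [(3, 1, 0), (2, 2, 0), (3, 1, 0), (2, 2, 0)]
example_fused_dim = flattened_size((100, 100), example_layers, out_channels=64)
```

Here each spatial axis shrinks 100 → 98 → 49 → 47 → 23, so example_fused_dim is 64 × 23 × 23 = 33856. Passing three spatial sizes, e.g. (depth, height, width), covers 3D images in the same way.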

attention

Multi-head attention layer that takes attention_embed_dim features as input.

Type:

nn.MultiheadAttention

img_dense

Linear layer that takes the output of the multi-head attention layer (attention_embed_dim features) as input.

Type:

nn.Linear

img_to_embed_dim

Linear layer that takes fused_dim features (the flattened image-layer output) as input; its output is an input to the multi-head attention layer.

Type:

nn.Linear

tab_to_embed_dim

Linear layer that takes fused_dim features as input; its output is an input to the multi-head attention layer.

Type:

nn.Linear
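The two projection layers bring both modalities to the same width, attention_embed_dim, before they reach the attention layer. The plain-Python sketch below mimics nn.Linear with a minimal, hypothetical Linear class and uses made-up sizes (fused_dim = 512, attention_embed_dim = 8); following the attribute descriptions above, both projections take fused_dim features.

```python
import random
from typing import List


class Linear:
    """Minimal fully connected layer y = W x + b, a plain-Python stand-in for nn.Linear."""

    def __init__(self, in_features: int, out_features: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        bound = 1.0 / in_features ** 0.5  # small random weights; exact init is unimportant here
        self.in_features = in_features
        self.weight = [[rng.uniform(-bound, bound) for _ in range(in_features)]
                       for _ in range(out_features)]
        self.bias = [rng.uniform(-bound, bound) for _ in range(out_features)]

    def __call__(self, x: List[float]) -> List[float]:
        if len(x) != self.in_features:
            raise ValueError(f"expected {self.in_features} features, got {len(x)}")
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


# Made-up sizes for illustration only.
fused_dim = 512
attention_embed_dim = 8

img_to_embed_dim = Linear(fused_dim, attention_embed_dim, seed=0)
tab_to_embed_dim = Linear(fused_dim, attention_embed_dim, seed=1)

img_features = [0.01 * i for i in range(fused_dim)]  # stand-in flattened image-layer output
tab_features = [1.0] * fused_dim                      # stand-in tabular representation

img_embed = img_to_embed_dim(img_features)  # length attention_embed_dim
tab_embed = tab_to_embed_dim(tab_features)  # length attention_embed_dim
```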

relu

ReLU activation function.

Type:

nn.ReLU

final_prediction

Sequential layer containing the final prediction layers.

Type:

nn.Sequential

__init__(prediction_task, data_dims, multiclass_dimensions)[source]
Parameters:
  • prediction_task (str) – Type of prediction to be performed.

  • data_dims (list) – List containing the dimensions of the data.

  • multiclass_dimensions (int) – Number of classes in the multiclass classification task.

calc_fused_layers()[source]

Calculate the fused layers.

Return type:

None

Raises:
  • ValueError – If the two modalities have a different number of layers.

  • ValueError – If the layers are not of type nn.ModuleDict.

  • ValueError – If the image layers do not match the image dimensions (e.g. nn.Conv2d layers used for 3D images, or nn.Conv3d layers used for 2D images).
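The error conditions listed for calc_fused_layers() can be sketched as below. This is a hypothetical stand-in rather than the library's check: each modality's layers are modelled as a plain dict mapping a layer name to the class names of its modules instead of an nn.ModuleDict, and the function name check_fusion_layers is made up.

```python
from typing import Dict, List

# Hypothetical stand-in for an nn.ModuleDict: layer name -> class names of its modules.
LayerSpec = Dict[str, List[str]]


def check_fusion_layers(mod1_layers: LayerSpec, img_layers: LayerSpec, img_ndim: int) -> None:
    """Raise ValueError under the conditions listed for calc_fused_layers()."""
    # Both containers must be dictionaries (nn.ModuleDict in the real model).
    if not isinstance(mod1_layers, dict) or not isinstance(img_layers, dict):
        raise ValueError("Layers must be given as a dictionary (nn.ModuleDict).")
    # Both modalities must have the same number of layers.
    if len(mod1_layers) != len(img_layers):
        raise ValueError(
            f"Different numbers of layers: {len(mod1_layers)} tabular vs {len(img_layers)} image."
        )
    # Guard added for this sketch only: images are assumed to be 2D or 3D.
    if img_ndim not in (2, 3):
        raise ValueError(f"Expected 2D or 3D images, got {img_ndim}D.")
    # Convolutions must match the image dimensionality: Conv2d for 2D, Conv3d for 3D.
    wrong_conv = "Conv3d" if img_ndim == 2 else "Conv2d"
    for name, modules in img_layers.items():
        if wrong_conv in modules:
            raise ValueError(f"{name} uses {wrong_conv}, but the images are {img_ndim}D.")


# A valid configuration passes silently.
tab_layers = {"layer 1": ["Linear", "ReLU"], "layer 2": ["Linear", "ReLU"]}
img_layers_2d = {"layer 1": ["Conv2d", "ReLU", "MaxPool2d"],
                 "layer 2": ["Conv2d", "ReLU", "MaxPool2d"]}
check_fusion_layers(tab_layers, img_layers_2d, img_ndim=2)
```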

forward(x1, x2)[source]

Forward pass of the model.

Parameters:
  • x1 (torch.Tensor) – Input tensor for the first modality.

  • x2 (torch.Tensor) – Input tensor for the second modality (image data).

Returns:

Output tensor.

Return type:

torch.Tensor

fusion_type = 'attention'

Type of fusion.

Type:

str

get_fused_dim()[source]

Compute the number of features of the fused layers and store it in fused_dim.

Return type:

None

method_name = 'Crossmodal multi-head attention'

Name of the method.

Type:

str

modality_type = 'tabular_image'

Type of modality.

Type:

str