fusilli.fusionmodels.tabularimagefusion.crossmodal_att
Crossmodal multi-head attention model. This model uses self-attention and cross-modal attention between the two modalities: tabular and image.
Classes
CrossmodalMultiheadAttention: Crossmodal multi-head attention model.
- class CrossmodalMultiheadAttention(prediction_task, data_dims, multiclass_dimensions)
Bases: ParentFusionModel, Module
Crossmodal multi-head attention model. This model uses self-attention and cross-modal attention between the two modalities: tabular and image.
References
Golovanevsky, M., Eickhoff, C., & Singh, R. (2022). Multimodal attention-based deep learning for Alzheimer's disease diagnosis. Journal of the American Medical Informatics Association, 29(12), 2014–2022. https://doi.org/10.1093/jamia/ocac168
https://github.com/rsinghlab/MADDi/blob/main/training/train_all_modalities.py
- prediction_task
Type of prediction to be performed.
- Type:
str
- attention_embed_dim
Embedding dimension (number of features) of the multi-head attention layer.
- Type:
int
- mod1_layers
Dictionary containing the layers of the first modality (tabular data).
- Type:
nn.ModuleDict
- img_layers
Dictionary containing the layers of the image data.
- Type:
nn.ModuleDict
- fused_dim
Number of features of the fused layers. This is the flattened output size of the image layers.
- Type:
int
- attention
Multi-head attention layer. Takes inputs with attention_embed_dim features.
- Type:
nn.MultiheadAttention
- img_dense
Linear layer applied to the output of the multi-head attention layer. Takes attention_embed_dim input features.
- Type:
nn.Linear
- img_to_embed_dim
Linear layer that takes fused_dim input features (the flattened image features) and projects them to attention_embed_dim features. Its output is an input of the multi-head attention layer.
- Type:
nn.Linear
- tab_to_embed_dim
Linear layer that takes fused_dim input features and projects them to attention_embed_dim features. Its output is an input of the multi-head attention layer.
- Type:
nn.Linear
- relu
ReLU activation function.
- Type:
nn.ReLU
- final_prediction
Sequential layer containing the final prediction layers.
- Type:
nn.Sequential
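The fused_dim attribute is described as the flattened output size of the image layers. A common way to obtain such a size in PyTorch is to push a dummy tensor through the layers and read off the flattened shape. The sketch below assumes the image layers are applied in ModuleDict order; the helper name flattened_output_size and the layer configuration are hypothetical and not part of fusilli.

```python
import torch
import torch.nn as nn


def flattened_output_size(img_layers: nn.ModuleDict, img_shape: tuple) -> int:
    """Return the number of features after running img_layers and flattening.

    Hypothetical helper. img_shape excludes the batch dimension,
    e.g. (channels, height, width) for 2D images.
    """
    x = torch.zeros(1, *img_shape)  # dummy batch containing a single image
    with torch.no_grad():  # shape inference only, no gradients needed
        for layer in img_layers.values():  # assumed to run in insertion order
            x = layer(x)
    return x.flatten(start_dim=1).shape[1]


# Hypothetical image layers for 1-channel 32x32 images.
img_layers = nn.ModuleDict({
    "layer 1": nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
    "layer 2": nn.Sequential(nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
})
fused_dim = flattened_output_size(img_layers, (1, 32, 32))
print(fused_dim)  # 16 channels * 8 * 8 = 1024
```

Measuring the size with a dummy pass avoids hand-deriving convolution and pooling arithmetic, so the value stays correct if the image layers are changed.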
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) – Type of prediction to be performed.
data_dims (list) – List containing the dimensions of the data.
multiclass_dimensions (int) – Number of classes in the multiclass classification task.
- calc_fused_layers()
Calculate the fused layers and validate the layers of both modalities.
- Return type:
None
- Raises:
ValueError – If the number of layers in the two modalities is different.
ValueError – If the layers are not of type nn.ModuleDict.
ValueError – If the image layers do not match the image dimensions (for example, Conv2d layers used for 3D images, or Conv3d layers used for 2D images).
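The three ValueError conditions listed for calc_fused_layers can be sketched as standalone checks. This is an illustration under stated assumptions, not fusilli's code: the function name check_fusion_layers and the img_dim argument (2 for 2D images, 3 for 3D images) are hypothetical, and the real method may perform the checks in a different order or with different messages.

```python
import torch.nn as nn


def check_fusion_layers(mod1_layers, img_layers, img_dim: int) -> None:
    """Hypothetical sketch of the Raises conditions; img_dim is 2 or 3."""
    # Both modalities must provide their layers as nn.ModuleDict.
    for name, layers in (("mod1_layers", mod1_layers), ("img_layers", img_layers)):
        if not isinstance(layers, nn.ModuleDict):
            raise ValueError(f"{name} must be an nn.ModuleDict, got {type(layers).__name__}")
    # The two modalities must have the same number of layers.
    if len(mod1_layers) != len(img_layers):
        raise ValueError(
            f"mod1_layers has {len(mod1_layers)} layers but img_layers has {len(img_layers)}"
        )
    # Convolution dimensionality must match the image dimensionality.
    if img_dim not in (2, 3):
        raise ValueError(f"img_dim must be 2 or 3, got {img_dim}")
    wrong_conv = nn.Conv3d if img_dim == 2 else nn.Conv2d
    for module in img_layers.modules():  # recurses into nested nn.Sequential blocks
        if isinstance(module, wrong_conv):
            raise ValueError(f"{wrong_conv.__name__} layer used for {img_dim}D images")


mod1 = nn.ModuleDict({"layer 1": nn.Linear(10, 32), "layer 2": nn.Linear(32, 64)})
img = nn.ModuleDict({"layer 1": nn.Conv2d(1, 8, 3), "layer 2": nn.Conv2d(8, 16, 3)})
check_fusion_layers(mod1, img, img_dim=2)  # valid configuration: returns None
try:
    check_fusion_layers(mod1, img, img_dim=3)
except ValueError as err:
    print(err)  # Conv2d layer used for 3D images
```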
- forward(x1, x2)
Forward pass of the model.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first modality (tabular data).
x2 (torch.Tensor) – Input tensor for the second modality (image data).
- Returns:
Output tensor.
- Return type:
torch.Tensor
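A minimal PyTorch sketch of the self-attention plus cross-modal attention pattern described for this model, assuming each modality is projected to a single attention_embed_dim-dimensional token and one shared nn.MultiheadAttention layer is reused for every attention call. It is not fusilli's implementation: the class name CrossmodalAttentionSketch, the num_heads value, the order of self- and cross-attention, and the concatenation head are all assumptions, and the image is flattened directly instead of first passing through image layers.

```python
import torch
import torch.nn as nn


class CrossmodalAttentionSketch(nn.Module):
    """Hypothetical sketch; not the fusilli CrossmodalMultiheadAttention class."""

    def __init__(self, tab_dim, img_flat_dim, embed_dim=32, num_heads=4, n_outputs=1):
        super().__init__()
        # Project each modality to the attention embedding dimension.
        self.tab_to_embed_dim = nn.Linear(tab_dim, embed_dim)
        self.img_to_embed_dim = nn.Linear(img_flat_dim, embed_dim)
        # One shared attention layer (embed_dim must be divisible by num_heads).
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.relu = nn.ReLU()
        # Assumed head: concatenate the four attention outputs, then one linear layer.
        self.final_prediction = nn.Linear(4 * embed_dim, n_outputs)

    def forward(self, x_tab, x_img):
        # (batch, features) -> (batch, 1, embed_dim): one token per modality.
        tab = self.relu(self.tab_to_embed_dim(x_tab)).unsqueeze(1)
        img = self.relu(self.img_to_embed_dim(x_img.flatten(start_dim=1))).unsqueeze(1)
        # Self-attention: each modality attends to itself.
        tab_self, _ = self.attention(tab, tab, tab)
        img_self, _ = self.attention(img, img, img)
        # Cross-modal attention: queries from one modality, keys/values from the other.
        tab_cross, _ = self.attention(tab_self, img_self, img_self)
        img_cross, _ = self.attention(img_self, tab_self, tab_self)
        fused = torch.cat([tab_self, img_self, tab_cross, img_cross], dim=-1).squeeze(1)
        return self.final_prediction(fused)  # (batch, n_outputs)


model = CrossmodalAttentionSketch(tab_dim=10, img_flat_dim=32 * 32)
out = model(torch.randn(8, 10), torch.randn(8, 1, 32, 32))
print(out.shape)  # torch.Size([8, 1])
```

With a single token per modality each attention call has exactly one key, so the attention weights are trivially 1; the sketch shows how queries, keys, and values are wired between the modalities rather than a meaningful attention pattern.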
- fusion_type = 'attention'
Type of fusion.
- Type:
str
- method_name = 'Crossmodal multi-head attention'
Name of the method.
- Type:
str
- modality_type = 'tabular_image'
Type of modality.
- Type:
str