"""
Crossmodal multi-head attention model. This model uses the self attention and cross modal attention
between the two modalities: tabular and image.
"""
import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
from torch.autograd import Variable
from fusilli.utils import check_model_validity
[docs]
class CrossmodalMultiheadAttention(ParentFusionModel, nn.Module):
    """
    Crossmodal multi-head attention model for tabular and image data.

    Each modality is passed through its own stack of layers and projected to
    ``attention_embed_dim`` features. A single shared multi-head attention
    layer is then applied as self-attention to each modality and as
    cross-modal attention between the two. The two cross-modal outputs are
    concatenated with the two projected embeddings and fed to the final
    prediction layers.

    References
    ----------
    Golovanevsky, M., Eickhoff, C., & Singh, R. (2022). Multimodal attention-based
    deep learning for Alzheimer’s disease diagnosis.
    Journal of the American Medical Informatics Association, 29(12),
    2014–2022. https://doi.org/10.1093/jamia/ocac168

    https://github.com/rsinghlab/MADDi/blob/main/training/train_all_modalities.py

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed.
    attention_embed_dim : int
        Number of features of the multihead attention layer.
    mod1_layers : nn.ModuleDict
        Dictionary containing the layers of the first modality.
    img_layers : nn.ModuleDict
        Dictionary containing the layers of the image data.
    fused_dim : int
        Flattened output size of the image layers for a single sample.
    attention : nn.MultiheadAttention
        Multihead attention layer with two heads. Takes in attention_embed_dim
        features as input. The same layer is used for every attention call in
        ``forward``.
    img_dense : nn.Linear
        Linear layer mapping attention_embed_dim features to a single output.
        It is created here but not used in ``forward``.
    img_to_embed_dim : nn.Linear
        Linear layer. Takes in fused_dim features (the flattened image output)
        and returns attention_embed_dim features.
    tab_to_embed_dim : nn.Linear
        Linear layer. Takes in the ``out_features`` of the last block in
        ``mod1_layers`` and returns attention_embed_dim features.
    relu : nn.ReLU
        ReLU activation function.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers.
    """

    #: str: Name of the method.
    method_name = "Crossmodal multi-head attention"
    #: str: Type of modality.
    modality_type = "tabular_image"
    #: str: Type of fusion.
    fusion_type = "attention"

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes in the multiclass classification task.
        """
        ParentFusionModel.__init__(
            self, prediction_task, data_dims, multiclass_dimensions
        )

        self.prediction_task = prediction_task
        self.attention_embed_dim = 50

        # The modality layers must exist before calc_fused_layers(), which
        # inspects them to size the attention and projection layers.
        self.set_mod1_layers()
        self.set_img_layers()
        self.calc_fused_layers()

        self.relu = nn.ReLU()

    def get_fused_dim(self):
        """
        Compute the flattened output size of the image layers.

        A single random dummy image of shape ``(1, *self.img_dim)`` is passed
        through ``self.img_layers`` and the number of resulting features is
        stored in ``self.fused_dim``.

        Returns
        -------
        None
            The result is stored in ``self.fused_dim``.
        """
        # torch.rand (rather than e.g. zeros) is kept on purpose so that the
        # amount of global RNG state consumed here is unchanged; only the
        # shape of the result matters.
        dummy_img = torch.rand((1,) + tuple(self.img_dim))

        # This is a shape probe only, so no autograd graph is needed.
        with torch.no_grad():
            for layer in self.img_layers.values():
                dummy_img = layer(dummy_img)

        # reshape (not view) so a non-contiguous layer output cannot fail here.
        self.fused_dim = dummy_img.reshape(1, -1).size(1)

    def calc_fused_layers(self):
        """
        Validate the modality layers and build the attention/fusion layers.

        Sets ``fused_dim``, ``attention``, ``img_dense``, ``img_to_embed_dim``
        and ``tab_to_embed_dim``, then sets up the final prediction layers
        with an input size of ``attention_embed_dim * 4``.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the number of layers in the two modalities is different.
        ValueError
            If dtype of the layers is not nn.ModuleDict.
        ValueError
            If the image dimensions are not valid. (Conv2D used for 3D img and vice versa)
        """
        check_model_validity.check_dtype(self.mod1_layers, nn.ModuleDict, "mod1_layers")
        check_model_validity.check_dtype(self.img_layers, nn.ModuleDict, "img_layers")
        check_model_validity.check_img_dim(self.img_layers, self.img_dim, "img_layers")

        # forward() steps through both stacks in lockstep, one block of each
        # per iteration, so they need the same number of blocks.
        if len(self.mod1_layers) != len(self.img_layers):
            raise ValueError(
                "The number of layers in the two modalities must be the same."
            )

        self.get_fused_dim()

        self.attention = nn.MultiheadAttention(
            embed_dim=self.attention_embed_dim, num_heads=2
        )
        self.img_dense = nn.Linear(self.attention_embed_dim, 1)

        # Project both modalities to the attention embedding size.
        self.img_to_embed_dim = nn.Linear(self.fused_dim, self.attention_embed_dim)
        # Reads out_features from the first sub-module of the last tabular
        # block, i.e. this relies on that block being indexable and starting
        # with a layer that exposes ``out_features`` (e.g. nn.Linear).
        self.tab_to_embed_dim = nn.Linear(
            list(self.mod1_layers.values())[-1][0].out_features,
            self.attention_embed_dim,
        )

        # Final input = 2 cross-modal outputs + 2 projected embeddings,
        # each with attention_embed_dim features.
        self.set_final_pred_layers(self.attention_embed_dim * 4)

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            Input tensor for the first modality.
        x2 : torch.Tensor
            Input tensor for the second modality. (Image data)

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        # ~~ Checks ~~
        check_model_validity.check_model_input(x1)
        check_model_validity.check_model_input(x2)

        # Drop dimension 1 of the tabular input if it has size 1.
        x_tab1 = x1.squeeze(dim=1)
        x_img = x2

        # Advance both modalities one block at a time; the image block is
        # looked up by the same key as the tabular block.
        for k, layer in self.mod1_layers.items():
            x_tab1 = layer(x_tab1)
            x_img = self.img_layers[k](x_img)

        # Flatten per sample and project to the attention embedding size.
        out_img = x_img.reshape(x_img.size(0), -1)
        out_img = self.img_to_embed_dim(out_img)
        out_img = self.relu(out_img)

        # NOTE(review): ReLU is applied to the image projection only, not to
        # the tabular projection — confirm this asymmetry is intended.
        out_tab1 = x_tab1.reshape(x_tab1.size(0), -1)
        out_tab1 = self.tab_to_embed_dim(out_tab1)

        # NOTE(review): the attention inputs below are 2-D. Per the PyTorch
        # docs, nn.MultiheadAttention treats 2-D inputs as one unbatched
        # sequence (length, embed_dim), so the first dimension (the batch here)
        # would act as the sequence axis — confirm this is intended.

        # self attention
        img_att = self.attention(out_img, out_img, out_img)[0]
        tab1_att = self.attention(out_tab1, out_tab1, out_tab1)[0]

        # cross modal attention
        tab1_img_att = self.attention(tab1_att, img_att, img_att)[0]
        img_tab1_att = self.attention(img_att, tab1_att, tab1_att)[0]

        crossmodal_att = torch.concat((tab1_img_att, img_tab1_att), dim=1)

        # concatenate
        merged = torch.concat((crossmodal_att, out_tab1, out_img), dim=1)

        out_fuse = self.final_prediction(merged)

        return out_fuse