Source code for fusilli.fusionmodels.tabularfusion.attention_and_activation

"""
Using activation functions to fuse tabular data, with self-attention on the second tabular modality.
"""

import torch.nn as nn

from fusilli.fusionmodels.base_model import ParentFusionModel
import torch

from fusilli.utils import check_model_validity


class AttentionAndSelfActivation(ParentFusionModel, nn.Module):
    """
    Applies an attention mechanism on the second tabular modality features and
    performs an element wise product of the feature maps of the two tabular
    modalities, tanh activation function and sigmoid activation function.
    Afterwards the first tabular modality feature map is concatenated with
    the fused feature map.

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed.
    mod1_layers : nn.ModuleDict
        Dictionary containing the layers of the first modality. Calculated in
        the :meth:`~ParentFusionModel.set_mod1_layers` method.
    mod2_layers : nn.ModuleDict
        Dictionary containing the layers of the second modality. Calculated in
        the :meth:`~ParentFusionModel.set_mod2_layers` method.
    fused_dim : int
        Number of features of the fused layers. In this method, it's the size
        of the tabular 1 layers output plus the size of the tabular 2 layers
        output.
    fused_layers : nn.Sequential
        Sequential layer containing the fused layers. Calculated in the
        :meth:`~AttentionAndSelfActivation.calc_fused_layers` method.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers. The final
        prediction layers take in the number of features of the fused layers
        as input. Calculated in the
        :meth:`~AttentionAndSelfActivation.calc_fused_layers` method.
    attention_reduction_ratio : int
        Reduction ratio of the channel attention module.
    """

    #: str: Name of the method.
    method_name = "Activation function and tabular self-attention"
    #: str: Type of modality.
    modality_type = "tabular_tabular"
    #: str: Type of fusion.
    fusion_type = "operation"

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes in the multiclass classification task.
        """
        ParentFusionModel.__init__(
            self, prediction_task, data_dims, multiclass_dimensions
        )

        self.prediction_task = prediction_task

        self.attention_reduction_ratio = 16

        self.set_mod1_layers()
        self.set_mod2_layers()

        # fused_dim has to exist before the fused layers can be sized.
        self.get_fused_dim()
        self.set_fused_layers(self.fused_dim)

        self.calc_fused_layers()

    def get_fused_dim(self):
        """
        Get the number of features of the fused layers.

        Stores the sum of the output sizes of the last mod1 and mod2 layers in
        ``self.fused_dim``. Assuming mod1_layers and mod2_layers output the
        same dimension.

        Returns
        -------
        None
        """
        # Assumes the last entry of each ModuleDict is an indexable container
        # whose first element exposes ``out_features`` (e.g. an nn.Linear) --
        # TODO confirm against ParentFusionModel.
        mod1_output_dim = list(self.mod1_layers.values())[-1][0].out_features
        mod2_output_dim = list(self.mod2_layers.values())[-1][0].out_features

        # New fused dimension is the sum of mod1 and mod2 output dimensions
        self.fused_dim = mod1_output_dim + mod2_output_dim

    def calc_fused_layers(self):
        """
        Calculate the fused layers.

        Validates the modality layers, recomputes ``fused_dim``, passes the
        fused layers through ``check_model_validity.check_fused_layers`` and
        sets the final prediction layers from the resulting output size.

        Returns
        -------
        None

        Raises
        ------
        UserWarning
            If the last layers of ``mod1_layers`` and ``mod2_layers`` do not
            have the same number of output features.
        """
        # ~~ Checks ~~
        check_model_validity.check_dtype(self.mod1_layers, nn.ModuleDict, "mod1_layers")
        check_model_validity.check_dtype(self.mod2_layers, nn.ModuleDict, "mod2_layers")

        # check that the output dimensions of the modality layers are the same:
        # forward() multiplies the two feature maps element-wise.
        mod1_output_dim = list(self.mod1_layers.values())[-1][0].out_features
        mod2_output_dim = list(self.mod2_layers.values())[-1][0].out_features
        if mod1_output_dim != mod2_output_dim:
            raise UserWarning(
                "The number of output features of mod1_layers and mod2_layers must be the same for ActivationandSelfAttention. Please change the final layers in the modality layers to have the same number of output features as each other."
            )

        self.get_fused_dim()

        self.fused_layers, out_dim = check_model_validity.check_fused_layers(
            self.fused_layers, self.fused_dim
        )

        # setting final prediction layers with final out features of fused layers
        self.set_final_pred_layers(out_dim)

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            Input tensor of the first modality.
        x2 : torch.Tensor
            Input tensor of the second modality.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        # ~~ Checks ~~
        check_model_validity.check_model_input(x1)
        check_model_validity.check_model_input(x2)

        x_tab1 = x1
        x_tab2 = x2

        num_channels = x_tab2.size(1)

        # Channel attention
        # NOTE(review): the attention module is built anew on every call and is
        # only held in a local variable, so its weights are freshly initialised
        # each time and are not registered as parameters of this model.
        channel_attention = ChannelAttentionModule(
            num_features=num_channels, reduction_ratio=self.attention_reduction_ratio
        )
        # A freshly built module lives on the CPU; move it next to the input so
        # the forward pass also works when the data is on an accelerator.
        channel_attention = channel_attention.to(x_tab2.device)
        x_tab2 = channel_attention(x_tab2)

        for layer in self.mod1_layers.values():
            x_tab1 = layer(x_tab1)

        for layer in self.mod2_layers.values():
            x_tab2 = layer(x_tab2)

        # torch.squeeze only drops dimension 1 when it has size 1.
        x_tab1 = torch.squeeze(x_tab1, 1)
        x_tab2 = torch.squeeze(x_tab2, 1)

        # Element-wise product followed by tanh and then sigmoid.
        out_fuse = torch.mul(x_tab1, x_tab2)
        out_fuse = torch.tanh(out_fuse)
        out_fuse = torch.sigmoid(out_fuse)

        # Concatenate the fused map with the first modality's feature map.
        out_fuse = torch.cat((out_fuse, x_tab1), dim=1)

        out_fuse = self.fused_layers(out_fuse)

        out = self.final_prediction(out_fuse)

        return out
class ChannelAttentionModule(nn.Module):
    """
    Channel attention module.

    Squeezes the input features through a bias-free two-layer bottleneck and
    uses the sigmoid of the result to rescale the input element-wise.

    Attributes
    ----------
    fc1 : nn.Linear
        First fully connected layer.
    relu : nn.ReLU
        ReLU activation function.
    fc2 : nn.Linear
        Second fully connected layer.
    sigmoid : nn.Sigmoid
        Sigmoid activation function.
    """

    def __init__(self, num_features, reduction_ratio=16):
        """
        Parameters
        ----------
        num_features: int
            Number of features of the input tensor.
        reduction_ratio: int
            Reduction ratio of the channel attention module.

        Raises
        ------
        UserWarning
            If ``num_features // reduction_ratio`` is smaller than 1.
        """
        super().__init__()

        # Width of the bottleneck between the two linear layers.
        bottleneck_dim = num_features // reduction_ratio
        if bottleneck_dim < 1:
            raise UserWarning(
                "first tabular modality dimensions // attention_reduction_ratio < 1. This will cause an error in the model."
            )

        self.fc1 = nn.Linear(num_features, bottleneck_dim, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(bottleneck_dim, num_features, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Forward pass of the channel attention module.

        Parameters
        ----------
        x: torch.tensor
            Input tensor.

        Returns
        -------
        torch.tensor
            Output tensor after applying the channel attention module.
        """
        # fc1 -> ReLU -> fc2 -> sigmoid gives one gate value per input feature.
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(x))))
        return x * gate.expand_as(x)