"""
Activation-function fusion model for tabular data.
"""
import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
from fusilli.utils import check_model_validity
[docs]
class ActivationFusion(ParentFusionModel, nn.Module):
    """
    Fuse two tabular modalities through an activation-function feature map.

    The feature maps produced by the two tabular modalities are multiplied
    element-wise, the product is passed through a tanh activation followed by
    a sigmoid activation, and the resulting fused feature map is concatenated
    with the feature map of the first tabular modality.

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed.
    mod1_layers : nn.ModuleDict
        Dictionary containing the layers of the first modality. Calculated in the
        :meth:`~ParentFusionModel.set_mod1_layers` method.
    mod2_layers : nn.ModuleDict
        Dictionary containing the layers of the second modality. Calculated in the
        :meth:`~ParentFusionModel.set_mod2_layers` method.
    fused_dim : int
        Number of features of the fused layers. In this method, it's the size of
        the tabular 1 layers output plus the size of the tabular 2 layers output.
    fused_layers : nn.Sequential
        Sequential layer containing the fused layers. Calculated in the
        :meth:`~ActivationFusion.calc_fused_layers` method.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers. The final
        prediction layers take in the number of features of the fused layers as
        input. Calculated in the :meth:`~ActivationFusion.calc_fused_layers` method.
    """

    #: str: Name of the method.
    method_name = "Activation function map fusion"
    #: str: Type of modality.
    modality_type = "tabular_tabular"
    #: str: Type of fusion.
    fusion_type = "operation"

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes in the multiclass classification problem.
        """
        ParentFusionModel.__init__(
            self,
            prediction_task,
            data_dims,
            multiclass_dimensions,
        )
        self.prediction_task = prediction_task

        # The per-modality layers have to exist before the fused dimension
        # can be read off their final output sizes.
        self.set_mod1_layers()
        self.set_mod2_layers()

        self.get_fused_dim()
        self.set_fused_layers(self.fused_dim)
        self.calc_fused_layers()

    def get_fused_dim(self):
        """
        Compute the number of input features of the fused layers.

        The value is stored in ``self.fused_dim`` as the sum of the
        ``out_features`` of the first module inside the last entry of
        ``mod1_layers`` and of ``mod2_layers``. The two modalities are assumed
        to output the same dimension.
        """
        mod1_final_block = list(self.mod1_layers.values())[-1]
        mod1_width = mod1_final_block[0].out_features

        mod2_final_block = list(self.mod2_layers.values())[-1]
        mod2_width = mod2_final_block[0].out_features

        # Fused input width = modality 1 output width + modality 2 output width.
        self.fused_dim = mod1_width + mod2_width

    def calc_fused_layers(self):
        """
        Validate the modality layers and (re)build the fused/prediction layers.

        Returns
        -------
        None

        Raises
        ------
        UserWarning
            If the final layers of ``mod1_layers`` and ``mod2_layers`` do not
            have the same number of output features.
        """
        # ~~ Checks ~~
        for layers, label in (
            (self.mod1_layers, "mod1_layers"),
            (self.mod2_layers, "mod2_layers"),
        ):
            check_model_validity.check_dtype(layers, nn.ModuleDict, label)

        mod1_width = list(self.mod1_layers.values())[-1][0].out_features
        mod2_width = list(self.mod2_layers.values())[-1][0].out_features

        # The element-wise product in ``forward`` needs matching widths.
        if mod1_width != mod2_width:
            raise UserWarning(
                "The number of output features of mod1_layers and mod2_layers must be the same for Activation fusion. Please change the final layers in the modality layers to have the same number of output features as each other."
            )

        # Refresh fused_dim in case the modality layers were changed after
        # construction.
        self.get_fused_dim()

        checked_layers, fused_out_dim = check_model_validity.check_fused_layers(
            self.fused_layers, self.fused_dim
        )
        self.fused_layers = checked_layers

        # The final prediction layers take the output width of the fused layers.
        self.set_final_pred_layers(fused_out_dim)

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            Input tensor of the first modality.
        x2 : torch.Tensor
            Input tensor of the second modality.

        Returns
        -------
        out : torch.Tensor
            Fused prediction.
        """
        # ~~ Checks ~~
        check_model_validity.check_model_input(x1)
        check_model_validity.check_model_input(x2)

        # Run each modality through its own stack of layers, in order.
        mod1_features = x1
        for block in self.mod1_layers.values():
            mod1_features = block(mod1_features)

        mod2_features = x2
        for block in self.mod2_layers.values():
            mod2_features = block(mod2_features)

        # Drop dimension 1 when it has size 1 (no-op otherwise).
        mod1_features = torch.squeeze(mod1_features, 1)
        mod2_features = torch.squeeze(mod2_features, 1)

        # Element-wise product -> tanh -> sigmoid gives the fused feature map.
        activation_map = torch.sigmoid(
            torch.tanh(torch.mul(mod1_features, mod2_features))
        )

        # Fused map first, then the modality 1 features, along dimension 1.
        fused_input = torch.cat((activation_map, mod1_features), dim=1)

        return self.final_prediction(self.fused_layers(fused_input))
"""
"""