"""
Decision fusion of two types of tabular data.
"""
import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
from fusilli.utils import check_model_validity
[docs]
class TabularDecision(ParentFusionModel, nn.Module):
"""
This class implements a model that fuses the two types of tabular data using a decision fusion
approach.
Attributes
----------
mod1_layers : nn.ModuleDict
Dictionary containing the layers of the 1st type of tabular data.
mod2_layers : nn.ModuleDict
Dictionary containing the layers of the 2nd type of tabular data.
fused_layers : nn.Sequential
Sequential layer containing the fused layers.
final_prediction_tab1 : nn.Sequential
Sequential layer containing the final prediction layers for the first tabular data.
final_prediction_tab2 : nn.Sequential
Sequential layer containing the final prediction layers for the second tabular data.
fusion_operation : function
Function that performs the fusion operation. Default is torch.mean(torch.stack([x, y]), dim=0).
"""
#: str: Name of the method.
method_name = "Tabular decision"
#: str: Type of modality.
modality_type = "tabular_tabular"
#: str: Type of fusion.
fusion_type = "operation"
[docs]
def __init__(self, prediction_task, data_dims, multiclass_dimensions):
"""
Parameters
----------
prediction_task : str
Type of prediction to be performed.
data_dims : list
List containing the dimensions of the data.
multiclass_dimensions : int
Number of classes in the multiclass classification task.
"""
ParentFusionModel.__init__(
self, prediction_task, data_dims, multiclass_dimensions
)
self.prediction_task = prediction_task
self.fusion_operation = lambda x, y: torch.mean(torch.stack([x, y]), dim=0)
self.set_mod1_layers()
self.set_mod2_layers()
self.calc_fused_layers()
[docs]
def calc_fused_layers(self):
"""
Calculates the fusion layers.
Returns
-------
None
"""
check_model_validity.check_var_is_function(
self.fusion_operation, "fusion_operation"
)
check_model_validity.check_dtype(self.mod1_layers, nn.ModuleDict, "mod1_layers")
check_model_validity.check_dtype(self.mod2_layers, nn.ModuleDict, "mod2_layers")
tab1_fused_dim = list(self.mod1_layers.values())[-1][0].out_features
self.set_final_pred_layers(tab1_fused_dim)
self.final_prediction_tab1 = self.final_prediction
tab2_fused_dim = list(self.mod2_layers.values())[-1][0].out_features
self.set_final_pred_layers(tab2_fused_dim)
self.final_prediction_tab2 = self.final_prediction
[docs]
def forward(self, x1, x2):
"""
Forward pass of the model.
Parameters
----------
x1 : torch.Tensor
Input tensor for the first modality.
x2 : torch.Tensor
Input tensor for the second modality.
Returns
-------
torch.Tensor
Output tensor.
"""
# ~~ Checks ~~
check_model_validity.check_model_input(x1)
check_model_validity.check_model_input(x2)
x_tab1 = x1
x_tab2 = x2
for i, (k, layer) in enumerate(self.mod1_layers.items()):
x_tab1 = layer(x_tab1)
for i, (k, layer) in enumerate(self.mod2_layers.items()):
x_tab2 = layer(x_tab2)
# predictions for each method
pred_tab1 = self.final_prediction_tab1(x_tab1)
pred_tab2 = self.final_prediction_tab2(x_tab2)
# Combine predictions by averaging them together
out_fuse = self.fusion_operation(pred_tab1, pred_tab2)
# out_fuse = torch.mean(torch.stack([pred_tab1, pred_tab2]), dim=0)
return out_fuse