# Source code for fusilli.fusionmodels.tabularimagefusion.decision

"""
Model that fuses the first tabular data and the image data using a decision
fusion approach.
"""

import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
from torch.autograd import Variable

from fusilli.utils import check_model_validity


class ImageDecision(ParentFusionModel, nn.Module):
    """
    Decision-level fusion of the first tabular data and the image data.

    Each modality is passed through its own layers and its own final
    prediction head; the two resulting predictions are then combined by
    ``fusion_operation``.

    Attributes
    ----------
    mod1_layers : nn.ModuleDict
        Dictionary containing the layers of the 1st type of tabular data.
    img_layers : nn.ModuleDict
        Dictionary containing the layers of the image data.
    final_prediction_tab1 : nn.Sequential
        Sequential layer containing the final prediction layers for the first
        tabular data.
    final_prediction_img : nn.Sequential
        Sequential layer containing the final prediction layers for the image
        data.
    fusion_operation : function
        Function that combines the two per-modality predictions. It is called
        as ``fusion_operation(pred_tab1, pred_img)``. Default is
        ``lambda x, y: torch.mean(torch.stack([x, y]), dim=0)``.

    .. warning::

        ``fusion_operation`` must accept **two** arguments (the tabular
        prediction and the image prediction) and return the fused prediction.
        The default stacks the two predictions on a new leading dimension and
        reduces over that dimension (``dim=0``), so the batch dimension of the
        predictions is left untouched. A replacement such as
        ``lambda x, y: torch.max(torch.stack([x, y]), dim=0).values`` follows
        the same convention.
    """

    #: str: Name of the method.
    method_name = "Image decision fusion"
    #: str: Type of modality.
    modality_type = "tabular_image"
    #: str: Type of fusion.
    fusion_type = "operation"

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data.
        multiclass_dimensions : int
            Number of classes in the multiclass classification task.
        """
        ParentFusionModel.__init__(
            self, prediction_task, data_dims, multiclass_dimensions
        )

        self.prediction_task = prediction_task

        # Element-wise mean of the two per-modality predictions.
        self.fusion_operation = lambda x, y: torch.mean(torch.stack([x, y]), dim=0)

        self.set_img_layers()
        self.set_mod1_layers()
        self.calc_fused_layers()

    def calc_fused_layers(self):
        """
        Validate the modality layers and build one prediction head per modality.

        The input size of the tabular head is read from the last tabular
        layer; the input size of the image head is measured by pushing a dummy
        image through ``img_layers``. The heads are stored in
        ``final_prediction_tab1`` and ``final_prediction_img``.
        """
        check_model_validity.check_var_is_function(
            self.fusion_operation, "fusion_operation"
        )
        check_model_validity.check_dtype(self.img_layers, nn.ModuleDict, "img_layers")
        check_model_validity.check_dtype(self.mod1_layers, nn.ModuleDict, "mod1_layers")
        check_model_validity.check_img_dim(self.img_layers, self.img_dim, "img_layers")

        # ~~ Tabular data ~~
        # Assumes the last tabular block is indexable and that its first
        # element exposes ``out_features`` (e.g. an nn.Linear).
        last_tab_block = list(self.mod1_layers.values())[-1]
        tab_fused_dim = last_tab_block[0].out_features

        # set_final_pred_layers stores its result in self.final_prediction,
        # so it is captured under a modality-specific name before the next
        # call overwrites it.
        self.set_final_pred_layers(tab_fused_dim)
        self.final_prediction_tab1 = self.final_prediction

        # ~~ Image data ~~
        # Only the output shape is needed, so no autograd graph is recorded.
        with torch.no_grad():
            dummy_conv_output = torch.rand((1,) + tuple(self.img_dim))
            for layer in self.img_layers.values():
                dummy_conv_output = layer(dummy_conv_output)
        img_fusion_size = dummy_conv_output.view(1, -1).size(1)

        self.set_final_pred_layers(img_fusion_size)
        self.final_prediction_img = self.final_prediction

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            First tabular data input.
        x2 : torch.Tensor
            Image data input.

        Returns
        -------
        torch.Tensor
            Fused prediction.
        """
        # ~~ Checks ~~
        check_model_validity.check_model_input(x1)
        check_model_validity.check_model_input(x2)

        x_tab1 = x1.squeeze(dim=1)
        x_img = x2

        for layer in self.mod1_layers.values():
            x_tab1 = layer(x_tab1)

        for layer in self.img_layers.values():
            x_img = layer(x_img)

        # Flatten everything except the batch dimension for the image head.
        x_img = x_img.view(x_img.size(0), -1)

        # Predictions for each modality
        pred_tab1 = self.final_prediction_tab1(x_tab1)
        pred_img = self.final_prediction_img(x_img)

        # Combine the two predictions into the fused output.
        out_fuse = self.fusion_operation(pred_tab1, pred_img)

        return out_fuse