Source code for fusilli.fusionmodels.tabularimagefusion.concat_img_latent_tab_doubleloss

"""
Concatenate the image latent space with tabular data, trained all together with a custom loss function: MSE + BCE.
"""

import torch.nn as nn
from fusilli.fusionmodels.base_model import ParentFusionModel
import torch
from torch.autograd import Variable
import copy
from fusilli.utils import check_model_validity


class ConcatImgLatentTabDoubleLoss(ParentFusionModel, nn.Module):
    """
    Concatenating image latent space with tabular data, trained all together with a
    custom loss function: MSE + BCE.

    The image is passed through an encoder to a latent vector of size ``latent_dim``.
    That vector is (a) decoded back into an image, and (b) concatenated with the tabular
    data and passed through the fused layers and final prediction layers. Only the
    reconstruction loss (``custom_loss``, an ``nn.MSELoss``) is defined in this class;
    the prediction loss is expected to be applied by the training code.

    Attributes
    ----------
    prediction_task : str
        Type of prediction to be performed. Binary, regression or multiclass.
    fused_layers : nn.Sequential
        Sequential layer containing the fused layers defined with
        :func:`calc_fused_layers()`.
    final_prediction : nn.Sequential
        Sequential layer containing the final prediction layers. The final prediction
        layers take in the number of features of the fused layers as input.
    custom_loss : nn.Module
        Additional loss function to be used for training the model. Default is MSELoss.
    img_dim : tuple
        Image dimensions, taken from the last entry of ``data_dims``.
    latent_dim : int
        Size of the latent space. Default is 256.
    encoder : nn.Sequential
        Sequential layer containing the encoder layers. Default for 2D image is:

        .. code-block:: python

            nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

    decoder : nn.Sequential
        Sequential layer containing the decoder layers. Default for 2D image is:

        .. code-block:: python

            nn.Sequential(
                nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),
            )

    new_encoder : nn.Sequential
        Copy of ``encoder`` followed by ``nn.Flatten`` and an ``nn.Linear`` projecting to
        ``latent_dim``. Built in
        :meth:`~ConcatImgLatentTabDoubleLoss.calc_fused_layers()`.
    new_decoder : nn.Sequential
        Copy of ``decoder`` preceded by an ``nn.Linear`` + ``nn.Unflatten`` and followed
        by ``nn.Sigmoid`` + ``nn.Upsample`` (to ``img_dim``). Built in
        :meth:`~ConcatImgLatentTabDoubleLoss.calc_fused_layers()`.
    fused_dim : int
        Size of the fused layers: latent dimension size + tabular data dimension size.
    """

    #: str: Name of the method.
    method_name = "Trained Together Latent Image + Tabular Data"
    #: str: Type of modality.
    modality_type = "tabular_image"
    #: str: Type of fusion.
    fusion_type = "subspace"

    def __init__(self, prediction_task, data_dims, multiclass_dimensions):
        """
        Parameters
        ----------
        prediction_task : str
            Type of prediction to be performed.
        data_dims : list
            List containing the dimensions of the data. The last entry is used as the
            image dimensions.
        multiclass_dimensions : int
            Number of classes in the multiclass classification task.

        Raises
        ------
        ValueError
            If the image dimensions are neither 2D nor 3D.
        """
        ParentFusionModel.__init__(
            self, prediction_task, data_dims, multiclass_dimensions
        )

        self.prediction_task = prediction_task
        self.custom_loss = nn.MSELoss()
        self.img_dim = data_dims[-1]
        self.latent_dim = 256  # You can adjust the latent space size

        n_spatial_dims = len(self.img_dim)

        if n_spatial_dims == 2:  # 2D images
            # "Same"-padded convolutions keep the spatial size; each max-pool halves it
            # (rounding down).
            self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            # Each transposed convolution (kernel 2, stride 2) exactly doubles the
            # spatial size. The output is resized to img_dim by the Upsample layer that
            # calc_fused_layers() appends.
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),
            )
        elif n_spatial_dims == 3:  # 3D images
            self.encoder = nn.Sequential(
                nn.Conv3d(1, 16, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.MaxPool3d(kernel_size=2, stride=2),
                nn.Conv3d(16, 32, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.MaxPool3d(kernel_size=2, stride=2),
                nn.Conv3d(32, self.latent_dim, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.MaxPool3d(kernel_size=2, stride=2),
            )
            # Each transposed convolution (kernel 3, stride 1) grows every spatial
            # dimension by 2; the final size is fixed by the Upsample layer appended in
            # calc_fused_layers().
            self.decoder = nn.Sequential(
                nn.ConvTranspose3d(self.latent_dim, 32, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1),
            )
        else:
            raise ValueError("Invalid image dimensions.")

        self.get_fused_dim()
        # set_fused_layers() is not defined in this class; presumably inherited from
        # ParentFusionModel and responsible for creating self.fused_layers.
        self.set_fused_layers(self.fused_dim)
        self.calc_fused_layers()

    def get_fused_dim(self):
        """
        Get the number of features of the fused layers.

        Sets ``self.fused_dim`` to ``latent_dim + mod1_dim``. ``mod1_dim`` is not set in
        this class; it is presumably the tabular feature count set by
        ``ParentFusionModel.__init__``.

        Returns
        -------
        None
        """
        self.fused_dim = self.latent_dim + self.mod1_dim

    def calc_fused_layers(self):
        """
        Calculate the fused layers.

        If layer sizes are modified, this function will be called again to adjust the
        fused layers. It validates the user-modifiable attributes, then rebuilds
        ``new_encoder``, ``new_decoder`` and the final prediction layers.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``latent_dim`` is smaller than 1. The ``check_model_validity`` helpers
            may raise their own errors for invalid encoder/decoder/fused layers.
        """
        check_model_validity.check_dtype(self.encoder, nn.Sequential, "encoder")
        check_model_validity.check_dtype(self.decoder, nn.Sequential, "decoder")
        check_model_validity.check_dtype(self.latent_dim, int, "latent_dim")

        # Validate the range before latent_dim is used to derive any layer size.
        if self.latent_dim < 1:
            raise ValueError(
                "Incorrect attribute range: The latent dimension must be greater than 0. "
                f"The latent dimension is currently: {self.latent_dim}"
            )

        self.get_fused_dim()

        # check fused layers
        self.fused_layers, out_dim = check_model_validity.check_fused_layers(
            self.fused_layers, self.fused_dim
        )

        check_model_validity.check_img_dim(self.encoder, self.img_dim, "encoder")
        check_model_validity.check_img_dim(self.decoder, self.img_dim, "decoder")

        # Size of the final encoder output: push one random single-channel image
        # through the encoder. No gradients are needed for this shape probe.
        with torch.no_grad():
            dummy_encoded = self.encoder(torch.rand((1,) + tuple(self.img_dim)))
        n_size = dummy_encoded.numel()

        # add extra layers to encoder: flatten, then project to the latent space
        self.new_encoder = copy.deepcopy(self.encoder)
        self.new_encoder.append(nn.Flatten())
        self.new_encoder.append(nn.Linear(n_size, self.latent_dim))

        # add extra layers to decoder to get the right shape for the first decoding
        # layer: latent vector -> channel vector -> (channels, 1, 1[, 1]) feature map
        self.new_decoder = copy.deepcopy(self.decoder)
        first_decoder_layer_inchannels = self.new_decoder[0].in_channels
        self.new_decoder.insert(
            0, nn.Linear(self.latent_dim, first_decoder_layer_inchannels)
        )

        n_spatial_dims = len(self.img_dim)
        upsample_modes = {2: "bilinear", 3: "trilinear"}
        supported_dims = n_spatial_dims in upsample_modes

        if supported_dims:
            unflattened_shape = (first_decoder_layer_inchannels,) + (1,) * n_spatial_dims
            self.new_decoder.insert(1, nn.Unflatten(1, unflattened_shape))

        self.new_decoder.append(nn.Sigmoid())  # Output is scaled between 0 and 1

        if supported_dims:
            # Resize the reconstruction to the original image size.
            self.new_decoder.append(
                nn.Upsample(
                    size=self.img_dim,
                    mode=upsample_modes[n_spatial_dims],
                    align_corners=False,
                )
            )

        # set_final_pred_layers() is not defined in this class; presumably inherited
        # from ParentFusionModel and responsible for creating self.final_prediction.
        self.set_final_pred_layers(out_dim)

    def forward(self, x1, x2):
        """
        Forward pass of the model.

        Parameters
        ----------
        x1 : torch.Tensor
            Input tensor for the tabular data.
        x2 : torch.Tensor
            Input tensor for the image data.

        Returns
        -------
        list
            Two-element list ``[prediction, reconstructed_image]``: the output of the
            final prediction layers and the decoder's reconstruction of ``x2``.
        """
        check_model_validity.check_model_input(x1)
        check_model_validity.check_model_input(x2)

        # encoder: the encoded image is the latent space
        latent_space = self.new_encoder(x2)

        # decoder: new_decoder already ends with Sigmoid (+ Upsample), so the
        # reconstruction is in (0, 1). Applying another sigmoid here would squash it
        # into roughly (0.5, 0.73) and make the MSE target unreachable.
        reconstructed_image = self.new_decoder(latent_space)

        # concatenate latent space with tabular data
        fused_data = torch.cat([latent_space, x1], dim=1)

        # put fused data through some joint layers
        out_fuse = self.fused_layers(fused_data)

        # final prediction
        out = self.final_prediction(out_fuse)

        return [out, reconstructed_image]