fusilli.fusionmodels.tabularimagefusion.concat_img_latent_tab_doubletrain
Concatenating the latent image space and tabular data. The latent image space is trained separately from both the tabular data and the labels, using the concat_img_latent_tab_subspace_method class.
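Classes
- ConcatImgLatentTabDoubleTrain(prediction_task, data_dims, multiclass_dimensions) - Concatenating the latent image space and tabular data.
- ImgLatentSpace(data_dims) - PyTorch Lightning module: autoencoder to train the latent image space.
- concat_img_latent_tab_subspace_method(datamodule, k=None, max_epochs=1000, train_subspace=True) - Class containing the method to train the latent image space and to convert the image data to the latent image space.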
- class ConcatImgLatentTabDoubleTrain(prediction_task, data_dims, multiclass_dimensions)
Bases: ParentFusionModel, Module
Concatenating the latent image space and tabular data. The latent image space is trained separately from both the tabular data and the labels, using the concat_img_latent_tab_subspace_method class.
- prediction_task
Type of prediction to be performed.
- Type:
str
- latent_dim
Dimension of the latent image space after encoding. Taken from the subspace_method class and inferred from the dimensions of the input data to the model.
- Type:
int
- enc_img_layer
Linear layer to reduce the dimension of the latent image space. Calculated with calc_fused_layers().
- Type:
nn.Linear
- fused_dim
Dimension of the fused layers. Calculated with calc_fused_layers().
- Type:
int
- fused_layers
Sequential layer containing the fused layers. Calculated with calc_fused_layers().
- Type:
nn.Sequential
- final_prediction
Sequential layer containing the final prediction layers. The final prediction layers take the number of features output by the fused layers as input. Calculated with calc_fused_layers().
- Type:
nn.Sequential
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) - Type of prediction to be performed.
data_dims (list) - List containing the dimensions of the data.
multiclass_dimensions (int) - Number of classes in the multiclass classification task.
- calc_fused_layers()
Calculate the fused layers. If the layer sizes are modified, this function is called again to adjust the fused layers.
- Return type:
None
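A minimal sketch of what a recalculation method like calc_fused_layers() might do, assuming plain PyTorch and a hypothetical FusedLayersSketch class (not the fusilli implementation): the fused input size is the tabular width plus latent_dim, and the final prediction head is rebuilt to match the output width of the last linear layer in fused_layers, so that wider or narrower fused layers can be swapped in. The default hidden width of 64 and the single output unit are illustrative assumptions.

```python
import torch.nn as nn


class FusedLayersSketch(nn.Module):
    """Hypothetical stand-in; layer sizes here are illustrative assumptions."""

    def __init__(self, tab_dim: int, latent_dim: int, out_dim: int = 1):
        super().__init__()
        self.tab_dim = tab_dim
        self.latent_dim = latent_dim
        self.out_dim = out_dim
        # Default fused layers: one hidden layer of width 64 (assumed)
        self.fused_layers = nn.Sequential(nn.Linear(tab_dim + latent_dim, 64), nn.ReLU())
        self.calc_fused_layers()

    def calc_fused_layers(self) -> None:
        # Input to the fused layers: tabular features concatenated with latent image features
        self.fused_dim = self.tab_dim + self.latent_dim
        first = self.fused_layers[0]
        if first.in_features != self.fused_dim:
            # Resize only the first layer's input and keep its output width
            self.fused_layers[0] = nn.Linear(self.fused_dim, first.out_features)
        # The final prediction head takes the output width of the last linear layer
        linears = [m for m in self.fused_layers if isinstance(m, nn.Linear)]
        self.final_prediction = nn.Sequential(nn.Linear(linears[-1].out_features, self.out_dim))


model = FusedLayersSketch(tab_dim=10, latent_dim=64)
print(model.fused_dim)                        # 74
print(model.final_prediction[0].in_features)  # 64

# Modify the fused layers, then recalculate so the prediction head matches again
model.fused_layers = nn.Sequential(nn.Linear(74, 128), nn.ReLU(), nn.Linear(128, 16))
model.calc_fused_layers()
print(model.final_prediction[0].in_features)  # 16
```

Rebuilding only the head (and the first layer's input size) keeps any user-chosen hidden widths intact while guaranteeing the shapes line up.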
- forward(x1, x2)
Forward pass of the model.
- Parameters:
x1 (torch.Tensor) - Input tensor containing the tabular data.
x2 (torch.Tensor) - Input tensor containing the latent image space.
- Returns:
Output tensor.
- Return type:
torch.Tensor
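The forward pass described above amounts to concatenating the two inputs along the feature axis and passing the result through the fused layers and then the final prediction layers. A minimal plain-PyTorch sketch, assuming 2-D inputs of shape (batch, features), a hypothetical ConcatFusionSketch class, and a single output unit; fusilli's own layer sizes and output handling may differ:

```python
import torch
import torch.nn as nn


class ConcatFusionSketch(nn.Module):
    """Hypothetical illustration of concatenation fusion; not the fusilli class."""

    def __init__(self, tab_dim: int, latent_dim: int, hidden_dim: int = 32, out_dim: int = 1):
        super().__init__()
        self.fused_dim = tab_dim + latent_dim
        self.fused_layers = nn.Sequential(nn.Linear(self.fused_dim, hidden_dim), nn.ReLU())
        self.final_prediction = nn.Sequential(nn.Linear(hidden_dim, out_dim))

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # x1: tabular data (batch, tab_dim); x2: latent image features (batch, latent_dim)
        fused = torch.cat((x1, x2), dim=1)
        return self.final_prediction(self.fused_layers(fused))


model = ConcatFusionSketch(tab_dim=10, latent_dim=64)
out = model(torch.randn(8, 10), torch.randn(8, 64))
print(out.shape)  # torch.Size([8, 1])
```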
- fusion_type = 'subspace'
Type of fusion.
- Type:
str
- method_name = 'Pretrained Latent Image + Tabular Data'
Name of the method.
- Type:
str
- modality_type = 'tabular_image'
Type of modality.
- Type:
str
- subspace_method
Class containing the method to train the latent image space.
- Type:
class
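Because subspace_method holds a class rather than an instance, a caller can inspect the fusion model class and decide whether a separate pre-training stage is needed before the fusion model itself is trained. The sketch below illustrates that pattern with hypothetical names (SubspaceMethodStub, FusionModelStub, needs_subspace_training); it is an assumption about how such class-level metadata can be consumed, not fusilli's pipeline code.

```python
class SubspaceMethodStub:
    """Hypothetical stand-in for concat_img_latent_tab_subspace_method."""

    def __init__(self, datamodule, k=None, max_epochs=1000, train_subspace=True):
        self.datamodule = datamodule
        self.k = k
        self.max_epochs = max_epochs
        self.train_subspace = train_subspace


class FusionModelStub:
    # Class-level metadata mirroring the documented attributes
    fusion_type = "subspace"
    method_name = "Pretrained Latent Image + Tabular Data"
    modality_type = "tabular_image"
    subspace_method = SubspaceMethodStub  # the class itself, not an instance


def needs_subspace_training(model_cls) -> bool:
    # A model needs a pre-training stage if it declares a subspace method class
    return getattr(model_cls, "subspace_method", None) is not None


if needs_subspace_training(FusionModelStub):
    # Instantiate the subspace method only when needed (datamodule omitted here)
    subspace = FusionModelStub.subspace_method(datamodule=None, max_epochs=50)
    print(type(subspace).__name__, subspace.max_epochs)  # SubspaceMethodStub 50
```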
- class ImgLatentSpace(data_dims)
Bases: LightningModule
PyTorch Lightning module: autoencoder to train the latent image space.
- data_dims
Dimensions of the data.
- Type:
dict
- img_dim
Dimensions of the image data.
- Type:
tuple
- latent_dim
Dimension of the latent image space after encoding. Default is 64.
- Type:
int
- encoder
Sequential layer containing the encoder layers.
- Type:
nn.Sequential
- decoder
Sequential layer containing the decoder layers.
- Type:
nn.Sequential
- new_encoder
Sequential layer containing the encoder layers and the linear layer to reduce the dimension of the latent image space. Calculated with calc_fused_layers().
- Type:
nn.Sequential
- new_decoder
Sequential layer containing the decoder layers and the linear layer to increase the dimension of the latent image space. Calculated with calc_fused_layers().
- Type:
nn.Sequential
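A minimal plain-PyTorch sketch of an autoencoder with the encoder/decoder split and the default latent_dim of 64 described above. It assumes single-channel 28x28 images and fully connected layers (the actual encoder may use convolutions), subclasses nn.Module rather than LightningModule to stay dependency-light, and uses the hypothetical name ImgAutoencoderSketch:

```python
import math

import torch
import torch.nn as nn


class ImgAutoencoderSketch(nn.Module):
    """Hypothetical fully connected autoencoder; not the fusilli ImgLatentSpace."""

    def __init__(self, img_dim=(1, 28, 28), latent_dim: int = 64):
        super().__init__()
        self.img_dim = img_dim
        self.latent_dim = latent_dim
        n_pixels = math.prod(img_dim)
        # Encoder: image -> flat vector -> latent_dim features
        self.encoder = nn.Sequential(
            nn.Flatten(), nn.Linear(n_pixels, 256), nn.ReLU(), nn.Linear(256, latent_dim)
        )
        # Decoder: latent_dim features -> flat vector -> original image shape
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, n_pixels), nn.Unflatten(1, img_dim)
        )

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        # Latent representation only, as used after training
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reconstruction of the input image
        return self.decoder(self.encoder(x))


model = ImgAutoencoderSketch()
x = torch.randn(4, 1, 28, 28)
print(model.encode_image(x).shape)  # torch.Size([4, 64])
print(model(x).shape)               # torch.Size([4, 1, 28, 28])
```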
- __init__(data_dims)
- Parameters:
data_dims (list) - List containing the dimensions of the data.
- calc_fused_layers()
Calculate the fused layers. If the layer sizes are modified, this function is called again to adjust the fused layers.
- Return type:
None
- configure_optimizers()
Configure the optimizers of the model.
- Returns:
Adam optimizer.
- Return type:
torch.optim.Adam
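In PyTorch Lightning, the Trainer calls configure_optimizers() and then runs the optimization loop itself. The plain-PyTorch sketch below shows roughly what that loop does for an autoencoder: an Adam optimizer minimizing the mean squared error between the reconstruction and the input, with no labels involved. The tiny network, learning rate, and step count are illustrative assumptions, not fusilli's settings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny autoencoder over flattened 16-pixel "images"
model = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
images = torch.rand(32, 16)

# Equivalent of configure_optimizers(): an Adam optimizer over all parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

losses = []
for step in range(200):
    optimizer.zero_grad()
    reconstruction = model(images)
    loss = loss_fn(reconstruction, images)  # the target is the input itself
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```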
- encode_image(x)
Encode the image data. Used after the model has been trained to obtain the latent image space.
- Parameters:
x (torch.Tensor) - Input data.
- Returns:
Encoded image.
- Return type:
torch.Tensor
- forward(x)
Forward pass of the model.
- Parameters:
x (torch.Tensor) - Input data.
- Returns:
Output of the model: the reconstructed (decoded) image.
- Return type:
torch.Tensor
- class concat_img_latent_tab_subspace_method(datamodule, k=None, max_epochs=1000, train_subspace=True)
Bases: object
Class containing the method to train the latent image space and to convert the image data to the latent image space.
- datamodule
Data module containing the data.
- Type:
pl.LightningDataModule
- trainer
Lightning trainer.
- Type:
pytorch_lightning.Trainer
- autoencoder
Autoencoder to train the latent image space.
- Type:
- __init__(datamodule, k=None, max_epochs=1000, train_subspace=True)
- Parameters:
datamodule (class) - Data module containing the data.
k (int or None) - Number of folds for cross-validation. Default is None.
max_epochs (int) - Maximum number of epochs to train the latent image space. Default is 1000.
train_subspace (bool) - Whether to train the latent image space. Default is True. If False, no new trainer is created, and load_ckpt() must then be called to load the checkpoint of the latent image space.
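The train_subspace flag implies a simple contract: either a trainer exists and train() fits the autoencoder, or no trainer is created and load_ckpt() has to supply the weights before any conversion. A hypothetical pure-Python sketch of that contract; the class name, the placeholder trainer dictionary, the error message, and the checkpoint path are all assumptions for illustration:

```python
class SubspaceContractSketch:
    """Hypothetical illustration of the train_subspace / load_ckpt contract."""

    def __init__(self, k=None, max_epochs=1000, train_subspace=True):
        self.k = k
        self.max_epochs = max_epochs
        # Placeholder for a Lightning Trainer; only created when training is requested
        self.trainer = {"max_epochs": max_epochs} if train_subspace else None
        self.checkpoint_path = None

    def load_ckpt(self, checkpoint_path):
        # checkpoint_path is a list, matching the documented parameter type
        self.checkpoint_path = checkpoint_path

    def ready_to_convert(self) -> bool:
        # Weights come either from training or from a loaded checkpoint
        return self.trainer is not None or self.checkpoint_path is not None

    def convert_to_latent(self, test_dataset):
        if not self.ready_to_convert():
            raise RuntimeError("train_subspace=False: call load_ckpt() before converting data")
        return test_dataset  # a real implementation would encode the images here


pretrained = SubspaceContractSketch(train_subspace=False)
print(pretrained.trainer is None)     # True
print(pretrained.ready_to_convert())  # False
pretrained.load_ckpt(["checkpoints/autoencoder.ckpt"])  # hypothetical path
print(pretrained.ready_to_convert())  # True
```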
- convert_to_latent(test_dataset)
Convert the image data to the latent image space.
- Parameters:
test_dataset (Dataset) - Test dataset.
- Returns:
list - List containing the raw tabular data and the latent image space.
pd.DataFrame - DataFrame containing the labels.
list - List containing the dimensions of the data.
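A minimal sketch of the conversion step, assuming an already-trained encoder, tensors instead of a Dataset, and a hypothetical function name convert_to_latent_sketch. It encodes the images without tracking gradients and returns the three documented pieces in order. The labels are passed through as a tensor rather than the pd.DataFrame the real method returns, and the layout of the dimension list ([tab_dim, latent_dim]) is an assumption:

```python
import torch
import torch.nn as nn


def convert_to_latent_sketch(encoder: nn.Module, tab: torch.Tensor, images: torch.Tensor, labels: torch.Tensor):
    encoder.eval()  # disable training-only behavior such as dropout
    with torch.no_grad():  # the encoder is frozen, so no gradients are needed
        latent = encoder(images)
    data = [tab, latent]  # raw tabular data + latent image space
    dims = [tab.shape[1], latent.shape[1]]  # assumed layout of the dimension list
    return data, labels, dims


encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64))  # stand-in for a trained encoder
tab = torch.randn(5, 10)
images = torch.randn(5, 1, 28, 28)
labels = torch.randint(0, 2, (5,))

data, labels_out, dims = convert_to_latent_sketch(encoder, tab, images, labels)
print(data[1].shape, dims)  # torch.Size([5, 64]) [10, 64]
```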
- load_ckpt(checkpoint_path)
Load the checkpoint of the latent image space.
- Parameters:
checkpoint_path (list) - List containing the path to the checkpoint.
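Because checkpoint_path is documented as a list, loading reduces to unpacking the path and restoring the saved weights. A plain-PyTorch sketch under stated assumptions: the list holds exactly one path, the file contains a bare state_dict written with torch.save, and load_ckpt_sketch is a hypothetical name. A PyTorch Lightning .ckpt file instead nests the weights under its "state_dict" key, and LightningModule.load_from_checkpoint is the usual way to restore one.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_ckpt_sketch(model: nn.Module, checkpoint_path: list) -> nn.Module:
    (path,) = checkpoint_path  # assumption: exactly one path in the list
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # the loaded encoder is used for inference only
    return model


trained = nn.Linear(8, 4)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "autoencoder.pt")
    torch.save(trained.state_dict(), path)
    restored = load_ckpt_sketch(nn.Linear(8, 4), [path])

print(torch.equal(restored.weight, trained.weight))  # True
```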
- train(train_dataset, val_dataset)
Train the latent image space.
- Parameters:
train_dataset (Dataset) - Training dataset.
val_dataset (Dataset) - Validation dataset.
- Returns:
list - List containing the raw tabular data and the latent image space.
pd.DataFrame - DataFrame containing the labels.