fusilli.fusionmodels.base_model

Base Lightning module and parent class for all fusion models.

Classes

BaseModel(model[, metrics_list])

Base pytorch lightning model for all fusion models.

ParentFusionModel(prediction_task, ...)

Parent class for all fusion models.

class BaseModel(model, metrics_list=None)[source]

Bases: LightningModule

Base pytorch lightning model for all fusion models.

This class takes the specific fusion model as an input and provides the training and validation steps. The loss functions/metrics/activation function options are defined here and chosen based on the prediction type chosen by the user.

model

Fusion model class.

Type:

class

multiclass_dimensions

Number of classes for multiclass prediction. Default is 3 for making the metrics dictionary.

Type:

int

metrics

Dictionary of metrics, at least two. Key is the name and value is the function from MetricsCalculator.

Type:

dict

metrics_list

List of strings of names of metrics to use for model evaluation. Default None. If None, the metrics will be automatically selected based on the prediction task (AUROC, accuracy for binary/multiclass, R2 and MAE for regression).

Type:

list

train_mask

Mask for training data, used for the graph fusion methods instead of train/val split. Indicates which nodes are training nodes.

Type:

tensor

val_mask

Mask for validation data, used for the graph fusion methods instead of train/val split. Indicates which nodes are validation nodes.

Type:

tensor

loss_functions

Dictionary of loss functions, one for each prediction type.

Type:

dict

output_activation_functions

Dictionary of output activation functions, one for each prediction type.

Type:

dict

batch_val_reals

List of validation reals for each batch. Stored for later concatenation with rest of batches and access by Plotter class for plotting.

Type:

list

batch_val_preds

List of validation preds for each batch. Stored for later concatenation with rest of batches and access by Plotter class for plotting.

Type:

list

batch_val_logits

List of validation logits for each batch. Stored for later concatenation with rest of batches and access by Plotter class for plotting.

Type:

list

batch_train_reals

List of training reals for each batch. Stored for later concatenation with rest of batches and access by Plotter class for plotting.

Type:

list

batch_train_preds

List of training preds for each batch. Stored for later concatenation with rest of batches and access by Plotter class for plotting.

Type:

list

val_reals

Concatenated validation reals for all batches. Accessed by Plotter class for plotting.

Type:

tensor

val_preds

Concatenated validation preds for all batches. Accessed by Plotter class for plotting.

Type:

tensor

val_logits

Concatenated validation logits for all batches. Accessed by Plotter class for plotting.

Type:

tensor

train_reals

Concatenated training reals for all batches. Accessed by Plotter class for plotting.

Type:

tensor

train_preds

Concatenated training preds for all batches. Accessed by Plotter class for plotting.

Type:

tensor

__init__(model, metrics_list=None)[source]

Parameters:
  • model (class) – Fusion model class.

  • metrics_list (list or None) – List of metrics to use for model evaluation. Default None. If None, the metrics will be automatically selected based on the prediction task (AUROC, accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list will be used in the comparison evaluation figures to rank the models’ performances. Length must be 2 or more.

Return type:

None

configure_optimizers()[source]

Configure optimizers.

get_data_from_batch(batch)[source]

Get data from batch.

Parameters:

batch (tensor) – Batch of data.

Returns:

  • x (tensor) – Input data.

  • y (tensor) – Labels.

get_model_outputs(x)[source]

Get model outputs.

Parameters:

x (tensor) – Input data.

Returns:

  • logits (tensor) – Logits.

  • reconstructions (tensor) – Reconstructions (returned if the model has a custom loss function such as a subspace method)

Note

If you get an error here, check that the forward output of the fusion model is [out,] or [out, reconstructions].
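The note above describes the output shape expected from a fusion model's forward method. The following is a minimal sketch of how such an output could be unpacked. `unpack_forward_output` is a hypothetical helper written for illustration and is not part of fusilli; plain Python lists stand in for tensors.

```python
def unpack_forward_output(output):
    """Split a forward() result into (logits, reconstructions).

    Hypothetical helper, not part of fusilli. Accepts [out] or
    [out, reconstructions]; reconstructions is None when absent.
    """
    if not isinstance(output, (list, tuple)):
        raise TypeError(
            "forward() should return a list: [out] or [out, reconstructions]"
        )
    if len(output) == 1:
        return output[0], None
    if len(output) == 2:
        return output[0], output[1]
    raise ValueError(f"expected 1 or 2 outputs, got {len(output)}")


# Plain lists stand in for tensors here.
logits, recon = unpack_forward_output([[0.2, 0.8]])
logits_sub, recon_sub = unpack_forward_output([[0.2, 0.8], [1.0, 2.0, 3.0]])
```

A forward method that returns a bare tensor instead of a list would fail the first check, which is the kind of error the note refers to.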

get_model_outputs_and_loss(x, y, train=True)[source]

Get model outputs and loss.

Parameters:
  • x (tensor) – Input data.

  • y (tensor) – Labels.

  • train (bool) – Whether the data is training data.

Returns:

  • loss (tensor) – Loss.

  • end_output (tensor) – Final output.

  • logits (tensor) – Logits.

on_validation_epoch_end()[source]

Gets the final validation epoch outputs and metrics. When metrics are calculated at the validation step and logged with on_epoch=True, the batch metrics are averaged. However, some metrics don’t average well (e.g. R2). Therefore, the final validation metrics are calculated here on the full validation set.

Parameters:

outputs (list) – List of outputs.

Return type:

None
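To illustrate why R2 does not average well across batches, here is a small sketch in plain Python. The data are made-up numbers and `r2_score` is a local helper defined for this example, not the metric implementation fusilli uses.

```python
def r2_score(reals, preds):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_real = sum(reals) / len(reals)
    ss_tot = sum((r - mean_real) ** 2 for r in reals)
    ss_res = sum((r - p) ** 2 for r, p in zip(reals, preds))
    return 1 - ss_res / ss_tot


# Two illustrative validation batches (made-up numbers).
batch_reals = [[1.0, 2.0, 3.0, 4.0], [10.0, 11.0, 12.0, 13.0]]
batch_preds = [[1.5, 2.5, 2.5, 3.5], [10.5, 11.5, 11.5, 12.5]]

# Option 1: average the per-batch scores.
per_batch = [r2_score(r, p) for r, p in zip(batch_reals, batch_preds)]
mean_batch_r2 = sum(per_batch) / len(per_batch)

# Option 2: concatenate all batches, then score once.
all_reals = [v for batch in batch_reals for v in batch]
all_preds = [v for batch in batch_preds for v in batch]
full_r2 = r2_score(all_reals, all_preds)
```

With these numbers the per-batch average is 0.8, while the score over the concatenated set is about 0.988, because each batch only sees the variance of its own targets.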

predict_step(batch: Any, batch_idx: int, dataloader_idx: int = 0) → Any[source]

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multi-devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won’t be returned.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

static safe_squeeze(tensor)[source]

Squeeze tensor if it is not 1D.

Parameters:

tensor (tensor) – Tensor to be squeezed.

Returns:

Squeezed tensor.

Return type:

tensor

set_metrics(metrics_list)[source]

Set what metrics will be used to log and plot. If no metrics are passed, then the default metrics for the prediction task will be used.

Parameters:

metrics_list (list or None) – List of metrics to use for model evaluation. Default None. If None, the metrics will be automatically selected based on the prediction task (AUROC, accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list will be used in the comparison evaluation figures to rank the models’ performances. Length must be 2 or more.
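The selection rule described above can be sketched as follows. This is an illustration under stated assumptions, not fusilli's implementation: `resolve_metrics`, `DEFAULT_METRICS`, and the lowercase metric name strings are hypothetical, and the identifiers fusilli actually accepts may differ.

```python
# Assumed metric names; the identifiers fusilli accepts may differ.
DEFAULT_METRICS = {
    "binary": ["auroc", "accuracy"],
    "multiclass": ["auroc", "accuracy"],
    "regression": ["r2", "mae"],
}


def resolve_metrics(prediction_task, metrics_list=None):
    """Return the metrics to log; the first one ranks models in comparisons."""
    if metrics_list is None:
        # Fall back to the defaults for the prediction task.
        return list(DEFAULT_METRICS[prediction_task])
    if len(metrics_list) < 2:
        raise ValueError("metrics_list must contain at least 2 metrics")
    return list(metrics_list)


default_regression = resolve_metrics("regression")
custom = resolve_metrics("binary", ["accuracy", "auroc"])
ranking_metric = custom[0]  # used to rank models in comparison figures
```

Putting the metric you care about most first matters, since that is the one the comparison figures use for ranking.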

training_step(batch, batch_idx)[source]

Training step.

Parameters:
  • batch (tensor) – Batch of data.

  • batch_idx (int) – Batch index.

Returns:

loss – Loss.

Return type:

tensor

validation_step(batch, batch_idx)[source]

Validation step.

Parameters:
  • batch (tensor) – Batch of data.

  • batch_idx (int) – Batch index.

Return type:

None

class ParentFusionModel(prediction_task, data_dims, multiclass_dimensions)[source]

Bases: object

Parent class for all fusion models.

prediction_task

Type of prediction to be made. Options: binary, multiclass, regression.

Type:

str

mod1_dim

Dimension of modality 1.

Type:

int

mod2_dim

Dimension of modality 2.

Type:

int

img_dim

Dimensions of image modality. If using 2D images, then the dimensions will be (x, y). If using 3D images, then the dimensions will be (x, y, z).

Type:

tuple

multiclass_dimensions

Number of classes for multiclass prediction.

Type:

int

final_prediction

Final prediction layers.

Type:

nn.Sequential

mod1_layers

Modality 1 layers.

Type:

nn.ModuleDict

mod2_layers

Modality 2 layers.

Type:

nn.ModuleDict

img_layers

Image layers.

Type:

nn.ModuleDict

fused_layers

Fused layers.

Type:

nn.Sequential

__init__(prediction_task, data_dims, multiclass_dimensions)[source]

Parameters:
  • prediction_task (str) – Type of prediction to be made. Options: binary, multiclass, regression.

  • data_dims (list) – List of data dimensions.

  • multiclass_dimensions (int) – Number of classes for multiclass prediction.

set_final_pred_layers(input_dim=64)[source]

Sets final prediction layers.

Parameters:

input_dim (int) – Input dimension to final layers - may depend on fusion configuration.

Return type:

None

set_fused_layers(fused_dim)[source]

Sets layers for the fused modality.

Parameters:

fused_dim (int) – Dimension of the fused modality, i.e. the number of features after fusion. For example, if two modalities have 64 features each and the fusion method is concatenation, fused_dim is 128.

Return type:

None
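The fused_dim arithmetic for concatenation can be checked with a short sketch. `concat_fused_dim` is a hypothetical helper for illustration only, and plain lists stand in for per-sample feature tensors; other fusion methods (for example element-wise operations) would give a different fused_dim.

```python
def concat_fused_dim(*feature_dims):
    """Feature count after concatenating modalities along the feature axis."""
    return sum(feature_dims)


# One sample per modality, 64 features each; lists stand in for tensors.
mod1_features = [0.0] * 64
mod2_features = [1.0] * 64
fused = mod1_features + mod2_features  # concatenation of the two feature vectors

fused_dim = concat_fused_dim(len(mod1_features), len(mod2_features))
```

Here fused_dim is 128, matching the length of the concatenated feature vector, and that value is what would be passed to set_fused_layers.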

set_img_layers()[source]

Sets layers for image modality. If using 2D images, then the layers will use Conv2D layers. If using 3D images, then the layers will use Conv3D layers.

Return type:

None

set_mod1_layers()[source]

Sets layers for modality 1.

Return type:

None

set_mod2_layers()[source]

Sets layers for modality 2.

Return type:

None