fusilli.fusionmodels.base_model
Base Lightning module for all fusion models, and the parent class for all fusion models.
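Classes
| BaseModel(model, metrics_list=None) | Base PyTorch Lightning model for all fusion models. |
| ParentFusionModel(prediction_task, data_dims, multiclass_dimensions) | Parent class for all fusion models. |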
- class BaseModel(model, metrics_list=None)
Bases: LightningModule
Base PyTorch Lightning model for all fusion models.
This class takes a specific fusion model as input and provides its training and validation steps. The loss functions, metrics, and output activation functions are defined here and selected according to the prediction task chosen by the user.
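The task-dependent choice of loss and output activation can be sketched in plain Python. This is a minimal sketch, not fusilli's code: the real class operates on PyTorch tensors, and the pairings below (binary cross-entropy with a sigmoid, cross-entropy with a softmax, squared error with an identity output) are common conventions assumed for illustration. `select_functions` is a hypothetical helper.

```python
import math

# Minimal, framework-free sketch. The loss/activation pairings are assumptions,
# not fusilli's actual choices; scalars and lists stand in for tensors.


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(zs):
    m = max(zs)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def bce_with_logits(z, y):
    # Binary cross-entropy from a raw logit z and a 0/1 label y (stable form)
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def cross_entropy(zs, y):
    # Multiclass cross-entropy from raw logits zs and an integer class label y,
    # computed via log-sum-exp to avoid taking log(0)
    m = max(zs)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in zs))
    return log_sum_exp - zs[y]


def squared_error(z, y):
    return (z - y) ** 2


LOSS_FUNCTIONS = {
    "binary": bce_with_logits,
    "multiclass": cross_entropy,
    "regression": squared_error,
}

OUTPUT_ACTIVATION_FUNCTIONS = {
    "binary": sigmoid,
    "multiclass": softmax,
    "regression": lambda z: z,  # identity: regression outputs are used as-is
}


def select_functions(prediction_task):
    """Hypothetical helper: return (loss_fn, activation_fn) for a task."""
    if prediction_task not in LOSS_FUNCTIONS:
        raise ValueError(f"Unknown prediction task: {prediction_task!r}")
    return LOSS_FUNCTIONS[prediction_task], OUTPUT_ACTIVATION_FUNCTIONS[prediction_task]
```

A single task string then picks both functions at once, so the loss and the output activation cannot drift out of sync.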
- model
Fusion model class.
- Type:
class
- multiclass_dimensions
Number of classes for multiclass prediction. Defaults to 3; used when building the metrics dictionary.
- Type:
int
- metrics
Dictionary of metrics (at least two). Keys are metric names and values are the corresponding functions from MetricsCalculator.
- Type:
dict
- metrics_list
List of metric names (strings) to use for model evaluation. Defaults to None, in which case the metrics are selected automatically based on the prediction task (AUROC and accuracy for binary/multiclass, R2 and MAE for regression).
- Type:
list
- train_mask
Mask for the training data, used by the graph fusion methods in place of a train/validation split. Indicates which nodes are training nodes.
- Type:
tensor
- val_mask
Mask for the validation data, used by the graph fusion methods in place of a train/validation split. Indicates which nodes are validation nodes.
- Type:
tensor
- loss_functions
Dictionary of loss functions, one for each prediction type.
- Type:
dict
- output_activation_functions
Dictionary of output activation functions, one for each prediction type.
- Type:
dict
- batch_val_reals
List of validation real (ground-truth) values for each batch. Stored for later concatenation with the other batches and for access by the Plotter class for plotting.
- Type:
list
- batch_val_preds
List of validation predictions for each batch. Stored for later concatenation with the other batches and for access by the Plotter class for plotting.
- Type:
list
- batch_val_logits
List of validation logits for each batch. Stored for later concatenation with the other batches and for access by the Plotter class for plotting.
- Type:
list
- batch_train_reals
List of training real (ground-truth) values for each batch. Stored for later concatenation with the other batches and for access by the Plotter class for plotting.
- Type:
list
- batch_train_preds
List of training predictions for each batch. Stored for later concatenation with the other batches and for access by the Plotter class for plotting.
- Type:
list
- val_reals
Validation real values concatenated across all batches. Accessed by the Plotter class for plotting.
- Type:
tensor
- val_preds
Validation predictions concatenated across all batches. Accessed by the Plotter class for plotting.
- Type:
tensor
- val_logits
Validation logits concatenated across all batches. Accessed by the Plotter class for plotting.
- Type:
tensor
- train_reals
Training real values concatenated across all batches. Accessed by the Plotter class for plotting.
- Type:
tensor
- train_preds
Training predictions concatenated across all batches. Accessed by the Plotter class for plotting.
- Type:
tensor
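The split between the per-batch `batch_*` lists and the concatenated attributes can be illustrated with a small collector. This is a sketch under assumptions: plain lists stand in for tensors, flattening stands in for `torch.cat`, `EpochCollector` and its method names are hypothetical, and clearing the per-batch lists after concatenation is an assumed way of separating one epoch from the next.

```python
# Hypothetical sketch: lists stand in for tensors, and flattening the stored
# batches stands in for torch.cat along the batch dimension.


class EpochCollector:
    def __init__(self):
        self.batch_val_reals = []
        self.batch_val_preds = []
        self.val_reals = []
        self.val_preds = []

    def on_validation_batch(self, reals, preds):
        # Store each batch's values; nothing is combined yet
        self.batch_val_reals.append(list(reals))
        self.batch_val_preds.append(list(preds))

    def on_validation_epoch_end(self):
        # Concatenate the stored batches into one flat sequence per attribute
        self.val_reals = [v for batch in self.batch_val_reals for v in batch]
        self.val_preds = [v for batch in self.batch_val_preds for v in batch]
        # Assumed: reset the per-batch lists so the next epoch starts empty
        self.batch_val_reals.clear()
        self.batch_val_preds.clear()


collector = EpochCollector()
collector.on_validation_batch([1.0, 2.0], [1.1, 1.9])
collector.on_validation_batch([3.0], [2.8])
collector.on_validation_epoch_end()
print(collector.val_reals)  # [1.0, 2.0, 3.0]
```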
- __init__(model, metrics_list=None)
- Parameters:
model (class) – Fusion model class.
metrics_list (list or None) – List of metrics to use for model evaluation. Defaults to None, in which case the metrics are selected automatically based on the prediction task (AUROC and accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list is used to rank the models' performances in the comparison evaluation figures. The list must contain at least two metrics.
- Return type:
None
- get_data_from_batch(batch)
Get data from batch.
- Parameters:
batch (tensor) – Batch of data.
- Returns:
x (tensor) – Input data.
y (tensor) – Labels.
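How a batch might split into `x` and `y` can be sketched as follows. The batch layout here is an assumption made for illustration (a sequence of one or more modality inputs followed by the labels); the source does not specify it, and the graph fusion methods in particular may handle batches differently. The function below is a hypothetical stand-in, not fusilli's implementation.

```python
# Hypothetical sketch: assumes a batch is a sequence of modality inputs
# followed by the labels, e.g. (x1, y) or (x1, x2, y).


def get_data_from_batch(batch):
    if len(batch) < 2:
        raise ValueError("expected at least one input and the labels")
    *inputs, y = batch
    # A single modality is returned on its own; several are grouped in a tuple
    x = inputs[0] if len(inputs) == 1 else tuple(inputs)
    return x, y
```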
- get_model_outputs(x)
Get model outputs.
- Parameters:
x (tensor) – Input data.
- Returns:
logits (tensor) – Logits.
reconstructions (tensor) – Reconstructions (returned only if the model has a custom loss function, such as a subspace method).
Note
If you get an error here, check that the fusion model's forward() output is [out,] or [out, reconstructions].
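The note above amounts to unpacking a one- or two-element output list. A minimal sketch, where `model_output` stands in for whatever the fusion model's forward() returns; returning None when there are no reconstructions is an assumption of this sketch, not documented behavior.

```python
# Hypothetical sketch: `model_output` plays the role of the list returned by
# a fusion model's forward(), either [out] or [out, reconstructions].


def get_model_outputs(model_output):
    if not isinstance(model_output, (list, tuple)) or len(model_output) not in (1, 2):
        raise TypeError("forward() should return [out,] or [out, reconstructions]")
    logits = model_output[0]
    # Assumed convention: no reconstructions -> None
    reconstructions = model_output[1] if len(model_output) == 2 else None
    return logits, reconstructions
```

Raising early with a message that names the expected shape makes the failure mode described in the note easier to diagnose than an index error deep inside the training step.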
- get_model_outputs_and_loss(x, y, train=True)
Get model outputs and loss.
- Parameters:
x (tensor) – Input data.
y (tensor) – Labels.
train (bool) – Whether the data is training data.
- Returns:
loss (tensor) – Loss.
end_output (tensor) – Final output.
logits (tensor) – Logits.
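For the graph fusion methods, which rely on `train_mask` and `val_mask` rather than a train/validation split, one plausible role for the `train` flag is choosing which mask selects the nodes that enter the loss. The sketch below assumes exactly that, uses lists of floats in place of tensors, and uses mean squared error purely as an example loss; `apply_mask` and `masked_loss` are hypothetical names.

```python
# Hypothetical sketch: lists stand in for tensors; mean squared error is used
# only as an example loss. All nodes are scored, but only masked nodes count.


def apply_mask(values, mask):
    return [v for v, keep in zip(values, mask) if keep]


def masked_loss(logits, y, train_mask, val_mask, train=True):
    mask = train_mask if train else val_mask  # assumed role of the train flag
    preds, targets = apply_mask(logits, mask), apply_mask(y, mask)
    if not preds:
        raise ValueError("mask selects no nodes")
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


logits = [1.0, 2.0, 3.0, 4.0]
y = [1.0, 0.0, 3.0, 0.0]
train_mask = [True, False, True, False]
val_mask = [False, True, False, True]
print(masked_loss(logits, y, train_mask, val_mask, train=True))   # 0.0
print(masked_loss(logits, y, train_mask, val_mask, train=False))  # 10.0
```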
- on_validation_epoch_end()
Gets the final validation epoch outputs and metrics. When metrics are calculated in the validation step and logged with on_epoch=True, the per-batch values are averaged. However, some metrics (e.g. R2) do not average well across batches, so the final validation metrics are calculated here on the full validation set.
- Parameters:
outputs (list) – List of outputs.
- Return type:
None
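A small worked example shows why R2 is recomputed on the full validation set instead of being averaged over batches. The numbers are made up for illustration, and `r2_score` is a plain-Python implementation of the standard coefficient of determination, R2 = 1 - SS_res / SS_tot, not a fusilli function.

```python
def r2_score(y_true, y_pred):
    # Coefficient of determination: 1 - SS_res / SS_tot
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R2 is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


# Two validation batches (illustrative values)
batches = [
    ([1.0, 2.0], [1.5, 1.5]),
    ([3.0, 4.0], [3.5, 3.5]),
]

per_batch = [r2_score(t, p) for t, p in batches]
per_batch_mean = sum(per_batch) / len(per_batch)  # 0.0

all_true = [v for t, _ in batches for v in t]
all_pred = [v for _, p in batches for v in p]
full_r2 = r2_score(all_true, all_pred)  # approximately 0.8
```

Each batch's R2 is measured against that batch's own mean, so the spread in targets between batches never enters SS_tot. Averaging the per-batch scores therefore gives 0.0 here, while the full set scores about 0.8.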
- predict_step(batch: Any, batch_idx: int, dataloader_idx: int = 0) → Any
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
predict_step() is used to scale inference across multiple devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of the epoch.
The BasePredictionWriter should be used when using a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or when training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), as predictions won't be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Predicted output (optional).
Example
    from lightning.pytorch import LightningModule, Trainer

    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            # Run the model's forward() on the batch
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=2)
    predictions = trainer.predict(model, dm)
- static safe_squeeze(tensor)
Squeezes the tensor unless it is already 1D.
- Parameters:
tensor (tensor) – Tensor to be squeezed.
- Returns:
Squeezed tensor.
- Return type:
tensor
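Why a 1D tensor is left alone can be shown on shapes alone. This sketch assumes safe_squeeze follows its one-line description literally: it models `torch.squeeze()`, which removes every dimension of size 1, on shape tuples, and `squeeze_shape` and `safe_squeeze_shape` are hypothetical helpers.

```python
# Hypothetical sketch on shape tuples rather than tensors.


def squeeze_shape(shape):
    # Like torch.squeeze() with no dim argument: drop every size-1 dimension
    return tuple(d for d in shape if d != 1)


def safe_squeeze_shape(shape):
    if len(shape) == 1:
        return shape  # already 1D: leave it, so a single-element batch stays 1D
    return squeeze_shape(shape)


print(squeeze_shape((1,)))         # () -> a 0-D scalar, losing the batch axis
print(safe_squeeze_shape((1,)))    # (1,)
print(safe_squeeze_shape((8, 1)))  # (8,)
```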
- set_metrics(metrics_list)
Sets the metrics that will be logged and plotted. If no metrics are passed, the default metrics for the prediction task are used.
- Parameters:
metrics_list (list or None) – List of metrics to use for model evaluation. Defaults to None, in which case the metrics are selected automatically based on the prediction task (AUROC and accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list is used to rank the models' performances in the comparison evaluation figures. The list must contain at least two metrics.
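The default-and-validation rules for metrics_list can be sketched as follows. The metric name strings and `choose_metrics` are hypothetical; only the per-task defaults, the minimum length of two, and the ranking role of the first entry come from the description above.

```python
# Hypothetical sketch; the metric name strings are placeholders.
DEFAULT_METRICS = {
    "binary": ["auroc", "accuracy"],
    "multiclass": ["auroc", "accuracy"],
    "regression": ["r2", "mae"],
}


def choose_metrics(prediction_task, metrics_list=None):
    if metrics_list is None:
        if prediction_task not in DEFAULT_METRICS:
            raise ValueError(f"Unknown prediction task: {prediction_task!r}")
        return list(DEFAULT_METRICS[prediction_task])
    if len(metrics_list) < 2:
        raise ValueError("metrics_list must contain at least two metrics")
    return list(metrics_list)


metrics = choose_metrics("regression")
ranking_metric = metrics[0]  # the first metric ranks models in comparison figures
print(metrics, ranking_metric)  # ['r2', 'mae'] r2
```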
- class ParentFusionModel(prediction_task, data_dims, multiclass_dimensions)
Bases: object
Parent class for all fusion models.
- prediction_task
Type of prediction to be made. Options: binary, multiclass, regression.
- Type:
str
- mod1_dim
Dimension of modality 1.
- Type:
int
- mod2_dim
Dimension of modality 2.
- Type:
int
- img_dim
Dimensions of image modality. If using 2D images, then the dimensions will be (x, y). If using 3D images, then the dimensions will be (x, y, z).
- Type:
tuple
- multiclass_dimensions
Number of classes for multiclass prediction.
- Type:
int
- final_prediction
Final prediction layers.
- Type:
nn.Sequential
- mod1_layers
Modality 1 layers.
- Type:
nn.ModuleDict
- mod2_layers
Modality 2 layers.
- Type:
nn.ModuleDict
- img_layers
Image layers.
- Type:
nn.ModuleDict
- fused_layers
Fused layers.
- Type:
nn.Sequential
- __init__(prediction_task, data_dims, multiclass_dimensions)
- Parameters:
prediction_task (str) – Type of prediction to be made. Options: binary, multiclass, regression.
data_dims (list) – List of data dimensions.
multiclass_dimensions (int) – Number of classes for multiclass prediction.
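data_dims is only described as a list of data dimensions. Assuming it is ordered as [mod1_dim, mod2_dim, img_dim] to match the attributes above, with None for a modality that is not used, it could be unpacked as in this sketch; the ordering, the None convention, and `unpack_data_dims` are all assumptions.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class FusionDims:
    mod1_dim: Optional[int]
    mod2_dim: Optional[int]
    img_dim: Optional[Tuple[int, ...]]  # (x, y) for 2D images, (x, y, z) for 3D


def unpack_data_dims(data_dims: Sequence) -> FusionDims:
    # Assumed ordering: [mod1_dim, mod2_dim, img_dim]; None marks an unused modality
    if len(data_dims) != 3:
        raise ValueError("expected [mod1_dim, mod2_dim, img_dim]")
    mod1_dim, mod2_dim, img_dim = data_dims
    return FusionDims(mod1_dim, mod2_dim, tuple(img_dim) if img_dim is not None else None)
```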
- set_final_pred_layers(input_dim=64)
Sets final prediction layers.
- Parameters:
input_dim (int) – Input dimension to the final layers; may depend on the fusion configuration.
- Return type:
None
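How input_dim and the prediction task could determine the size of the final layer is sketched below. It assumes the final prediction block ends in a single linear layer with one output for binary and regression tasks and multiclass_dimensions outputs for multiclass tasks; the returned pair would correspond to the `in_features` and `out_features` of a `torch.nn.Linear`. `final_layer_shape` is hypothetical.

```python
def final_layer_shape(prediction_task, input_dim=64, multiclass_dimensions=None):
    """Hypothetical helper: (in_features, out_features) of the last linear layer."""
    if prediction_task in ("binary", "regression"):
        out_features = 1  # one logit (binary) or one predicted value (regression)
    elif prediction_task == "multiclass":
        if not multiclass_dimensions:
            raise ValueError("multiclass_dimensions is required for multiclass tasks")
        out_features = multiclass_dimensions  # one logit per class
    else:
        raise ValueError(f"Unknown prediction task: {prediction_task!r}")
    return input_dim, out_features


# e.g. two 64-feature modalities fused by concatenation feed 128 features in
print(final_layer_shape("multiclass", input_dim=128, multiclass_dimensions=3))  # (128, 3)
```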
- set_fused_layers(fused_dim)
Sets the layers for the fused modality.
- Parameters:
fused_dim (int) – Dimension of the fused modality, i.e. the number of features after fusion. For example, with 2 modalities of 64 features each fused by concatenation, fused_dim would be 128.
- Return type:
None