fusilli.utils.model_chooser
This module contains functions to filter the fusion models based on conditions specified by the user. Conditions are given as a dictionary whose keys are the features to filter by (such as 'fusion_type' or 'modality_type') and whose values are the accepted values of those features. get_models returns a pandas DataFrame containing the models that match.
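A conditions dictionary may map each feature to a single condition or to a list of conditions; get_models and import_chosen_fusion_models accept both shapes. Reading a single condition as a one-element list makes the two shapes equivalent. The sketch below illustrates that normalization with a hypothetical helper, normalise_conditions, which is not part of fusilli; the set of accepted features is copied from the get_models parameter description.

```python
from typing import Dict, List, Union

# Features accepted by the filter, as listed in the get_models docstring.
ACCEPTED_FEATURES = {"fusion_type", "modality_type", "method_name", "class_name"}


def normalise_conditions(
    conditions: Dict[str, Union[str, List[str]]]
) -> Dict[str, List[str]]:
    """Return a copy of conditions in which every value is a list.

    Hypothetical helper: it only illustrates the two accepted shapes,
    {feature: condition} and {feature: [condition1, condition2, ...]}.
    """
    normalised = {}
    for feature, condition in conditions.items():
        if feature not in ACCEPTED_FEATURES:
            raise ValueError(f"Unknown feature: {feature!r}")
        # A bare string is treated as a one-element list of conditions.
        if isinstance(condition, str):
            normalised[feature] = [condition]
        else:
            normalised[feature] = list(condition)
    return normalised


print(normalise_conditions({"fusion_type": "graph", "modality_type": ["tabular1", "img"]}))
# {'fusion_type': ['graph'], 'modality_type': ['tabular1', 'img']}
```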
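Functions
| Function | Description |
|---|---|
| all_model_importer(fusion_model_dict[, skip_models]) | Imports all the fusion models in the fusion_model_dict. |
| get_models(conditions_dict[, skip_models, fusion_model_dict]) | Filters the models based on the conditions specified by the user. |
| import_chosen_fusion_models(model_conditions[, skip_models]) | Imports the fusion models specified by the user. |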
- all_model_importer(fusion_model_dict, skip_models=None)
Imports all the fusion models in the fusion_model_dict.
- Parameters:
fusion_model_dict (list) – List of dictionaries containing all the fusion models' names and paths. 'name' is the name of the class, and 'path' is the dotted import path of the module (.py file) that contains the class, e.g. 'fusionmodels.unimodal.tabular1'. Note: this list must be updated whenever a new fusion model is added.
skip_models (list) – List of models to skip when importing. Default is None. The list should consist of the class names of the models to skip, e.g. ['TabularDecision', 'ImgUnimodal']. This is useful if some models do not work properly in your environment.
- Returns:
fusion_models (list) – List of the imported fusion model classes.
fusion_model_dict_copy (list) – Copy of fusion_model_dict with the skipped models removed.
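Importing a class from a {'name': ..., 'path': ...} entry amounts to importing the module at the dotted path and then looking the class up by name. The sketch below shows that idea with the standard importlib module; it is a hypothetical illustration, not fusilli's implementation. The function name import_models and the package_prefix parameter are assumptions (package_prefix stands in for whatever package the real paths are relative to), and standard-library classes stand in for fusion models so the example runs without fusilli installed.

```python
import importlib
from typing import Dict, List, Optional, Tuple


def import_models(
    model_dict: List[Dict[str, str]],
    skip_models: Optional[List[str]] = None,
    package_prefix: str = "",
) -> Tuple[list, List[Dict[str, str]]]:
    """Import every class listed in model_dict except those named in skip_models.

    Hypothetical sketch, not fusilli's code: package_prefix stands in for the
    package that the real dotted paths are relative to.
    """
    skip = set(skip_models or [])
    # Copy each surviving entry so the caller's dictionaries are left untouched.
    kept = [dict(entry) for entry in model_dict if entry["name"] not in skip]
    classes = []
    for entry in kept:
        # Resolve the dotted module path, then look the class up by name.
        module = importlib.import_module(package_prefix + entry["path"])
        classes.append(getattr(module, entry["name"]))
    return classes, kept


# Standard-library classes stand in for fusion models so the sketch runs anywhere.
demo_dict = [
    {"name": "OrderedDict", "path": "collections"},
    {"name": "Fraction", "path": "fractions"},
    {"name": "Decimal", "path": "decimal"},
]
classes, kept = import_models(demo_dict, skip_models=["Decimal"])
print([cls.__name__ for cls in classes])  # ['OrderedDict', 'Fraction']
```

Returning the filtered copy alongside the classes keeps the two lists index-aligned, so the i-th class always corresponds to the i-th remaining entry.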
- get_models(conditions_dict, skip_models=None, fusion_model_dict=[
    {'name': 'Tabular1Unimodal', 'path': 'fusionmodels.unimodal.tabular1'},
    {'name': 'Tabular2Unimodal', 'path': 'fusionmodels.unimodal.tabular2'},
    {'name': 'ImgUnimodal', 'path': 'fusionmodels.unimodal.image'},
    {'name': 'ConcatTabularFeatureMaps', 'path': 'fusionmodels.tabularfusion.concat_feature_maps'},
    {'name': 'ConcatImageMapsTabularData', 'path': 'fusionmodels.tabularimagefusion.concat_img_maps_tabular_data'},
    {'name': 'ConcatTabularData', 'path': 'fusionmodels.tabularfusion.concat_data'},
    {'name': 'ConcatImageMapsTabularMaps', 'path': 'fusionmodels.tabularimagefusion.concat_img_maps_tabular_maps'},
    {'name': 'TabularChannelWiseMultiAttention', 'path': 'fusionmodels.tabularfusion.channelwise_att'},
    {'name': 'ImageChannelWiseMultiAttention', 'path': 'fusionmodels.tabularimagefusion.channelwise_att'},
    {'name': 'CrossmodalMultiheadAttention', 'path': 'fusionmodels.tabularimagefusion.crossmodal_att'},
    {'name': 'TabularCrossmodalMultiheadAttention', 'path': 'fusionmodels.tabularfusion.crossmodal_att'},
    {'name': 'TabularDecision', 'path': 'fusionmodels.tabularfusion.decision'},
    {'name': 'ImageDecision', 'path': 'fusionmodels.tabularimagefusion.decision'},
    {'name': 'MCVAE_tab', 'path': 'fusionmodels.tabularfusion.mcvae_model'},
    {'name': 'ConcatImgLatentTabDoubleTrain', 'path': 'fusionmodels.tabularimagefusion.concat_img_latent_tab_doubletrain'},
    {'name': 'ConcatImgLatentTabDoubleLoss', 'path': 'fusionmodels.tabularimagefusion.concat_img_latent_tab_doubleloss'},
    {'name': 'EdgeCorrGNN', 'path': 'fusionmodels.tabularfusion.edge_corr_gnn'},
    {'name': 'DAETabImgMaps', 'path': 'fusionmodels.tabularimagefusion.denoise_tab_img_maps'},
    {'name': 'AttentionWeightedGNN', 'path': 'fusionmodels.tabularfusion.attention_weighted_GNN'},
    {'name': 'AttentionAndSelfActivation', 'path': 'fusionmodels.tabularfusion.attention_and_activation'},
    {'name': 'ActivationFusion', 'path': 'fusionmodels.tabularfusion.activation'},
  ])
Filters the models based on the conditions specified by the user.
- Parameters:
conditions_dict (dict) –
Dictionary containing the conditions to filter the models. Structure: {feature1: condition, feature2: condition, …} or {feature1: [condition1, condition2, …], feature2: [condition1, …], …}
Accepted features and accepted conditions:
'fusion_type': 'unimodal', 'operation', 'attention', 'subspace', 'graph', or 'all'
'modality_type': 'tabular1', 'tabular2', 'img', 'tabular_tabular', 'tabular_image', or 'all'
'method_name': any method name currently implemented (e.g. 'Tabular decision'), or 'all'
'class_name': any model name currently implemented (e.g. 'TabularDecision'), or 'all'
Example: To get all the unimodal models and all the operation-based models, for every modality type, the dictionary would be:
conditions_dict = { "fusion_type": ["unimodal", "operation"], "modality_type": "all", }
fusion_model_dict (list) – List of dictionaries containing the fusion models' names and paths. Default is the module-level fusion_model_dict, i.e. the full list shown in the signature.
skip_models (list) – List of models to skip when importing. Default is None. The list should consist of the class names of the models to skip, e.g. ['TabularDecision', 'ImgUnimodal']. This is useful if some models do not work properly in your environment.
- Returns:
filtered_models – DataFrame containing the filtered models.
Column names:
'method_name': name of the model (e.g. 'Tabular decision')
'fusion_type': type of fusion (e.g. 'operation')
'modality_type': type of modality (e.g. 'tabular_tabular')
'class_name': name of the class (e.g. 'TabularDecision')
'method_path': path to the method's .py file (e.g. 'fusilli.fusionmodels.tabular_decision')
- Return type:
pd.DataFrame
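The filtering rules implied by the conditions dictionary can be sketched in plain Python. This is an assumed reading of the semantics, not fusilli's implementation: 'all' disables the filter for a feature, a list means "any of these values", and separate features are combined with AND. The function name filter_models and the demo rows are hypothetical; the rows mimic a few of the DataFrame columns above, and their labels are illustrative only.

```python
from typing import Dict, List, Union

Condition = Union[str, List[str]]


def filter_models(rows: List[Dict[str, str]], conditions: Dict[str, Condition]) -> List[Dict[str, str]]:
    """Keep the rows that satisfy every feature's condition.

    Assumed semantics: "all" disables a feature's filter, a list means
    "any of these values", and separate features are combined with AND.
    """
    result = rows
    for feature, condition in conditions.items():
        values = [condition] if isinstance(condition, str) else list(condition)
        if "all" in values:
            continue  # no restriction on this feature
        result = [row for row in result if row[feature] in values]
    return result


# Hypothetical rows mirroring some of the DataFrame columns; labels are illustrative.
rows = [
    {"class_name": "Tabular1Unimodal", "fusion_type": "unimodal", "modality_type": "tabular1"},
    {"class_name": "TabularDecision", "fusion_type": "operation", "modality_type": "tabular_tabular"},
    {"class_name": "EdgeCorrGNN", "fusion_type": "graph", "modality_type": "tabular_tabular"},
]

selected = filter_models(rows, {"fusion_type": ["unimodal", "operation"], "modality_type": "all"})
print([row["class_name"] for row in selected])  # ['Tabular1Unimodal', 'TabularDecision']
```

With a pandas DataFrame, as get_models returns, the same per-feature step could be written as df[df[feature].isin(values)].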
- import_chosen_fusion_models(model_conditions, skip_models=None)
Imports the fusion models specified by the user.
- Parameters:
model_conditions (dict) –
Dictionary containing the conditions to filter the models. Structure: {feature1: condition, feature2: condition, …} or {feature1: [condition1, condition2, …], feature2: [condition1, …], …}
Accepted features and accepted conditions:
'fusion_type': 'unimodal', 'operation', 'attention', 'subspace', 'graph', or 'all'
'modality_type': 'tabular1', 'tabular2', 'img', 'tabular_tabular', 'tabular_image', or 'all'
'method_name': any method name currently implemented (e.g. 'Tabular decision'), or 'all'
'class_name': any model name currently implemented (e.g. 'TabularDecision'), or 'all'
Example: To get all the unimodal models and all the operation-based models, for every modality type, the dictionary would be:
model_conditions = { "fusion_type": ["unimodal", "operation"], "modality_type": "all", }
skip_models (list) – List of models to skip when importing. Default is None. The list should consist of the class names of the models to skip, e.g. ['TabularDecision', 'ImgUnimodal']. This is useful if some models do not work properly in your environment.
- Returns:
fusion_models – List of the imported fusion model classes that match model_conditions.
- Return type:
list
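Conceptually, this function chains the two earlier steps: select the models that match the conditions, then import the matching classes. The sketch below is a hypothetical illustration of that composition, not fusilli's code. Unlike fusilli, where the labels are exposed as the DataFrame columns above, the sketch stores a fusion_type label directly in an invented registry (REGISTRY), and standard-library classes stand in for fusion models. Storing the labels in the registry is an assumption of the sketch that lets it import only the chosen modules.

```python
import importlib
from typing import Dict, List, Optional, Union

# Hypothetical registry: stdlib classes stand in for fusion models, and the
# "fusion_type" labels are invented purely for this demonstration.
REGISTRY = [
    {"name": "OrderedDict", "path": "collections", "fusion_type": "operation"},
    {"name": "Fraction", "path": "fractions", "fusion_type": "attention"},
    {"name": "Decimal", "path": "decimal", "fusion_type": "operation"},
]


def import_chosen(
    conditions: Dict[str, Union[str, List[str]]],
    skip_models: Optional[List[str]] = None,
) -> list:
    """Filter REGISTRY by conditions, then import the surviving classes."""
    skip = set(skip_models or [])
    chosen = [entry for entry in REGISTRY if entry["name"] not in skip]
    for feature, condition in conditions.items():
        values = [condition] if isinstance(condition, str) else list(condition)
        if "all" not in values:
            chosen = [entry for entry in chosen if entry[feature] in values]
    # Only the entries that passed every filter are imported.
    return [getattr(importlib.import_module(e["path"]), e["name"]) for e in chosen]


models = import_chosen({"fusion_type": "operation"}, skip_models=["Decimal"])
print([cls.__name__ for cls in models])  # ['OrderedDict']
```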