AngelosNal/Vision-DiffMask


VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking

📃 [Paper] 🚀 [Demo] 💾 [Checkpoints]

This repository contains the official PyTorch implementation of the paper "VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking" by Angelos Nalmpantis*, Apostolos Panagiotopoulos*, John Gkountouras*, Konstantinos Papakostas* and Wilker Aziz (CVPRW XAI4CV 2023).

Overview

Vision DiffMask is a post-hoc interpretation method for vision tasks. Given a pre-trained model, it predicts the minimal subset of the input (image patches) required to maintain the model's original output distribution. Currently, only the Vision Transformer (ViT) for image classification is supported.
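Following the original DiffMask formulation, the idea of "the minimal subset of patches that preserves the output distribution" can be written as a constrained problem. The notation below is our own sketch, not taken from this repository: x_i are the patch embeddings, z_i are the predicted gates, b is the placeholder that replaces masked patches, ε is the KL tolerance (`--eps`) and α is the Lagrange multiplier (`--alpha`). The exact objective implemented here may differ.

```latex
\begin{aligned}
&\tilde{x}_i = z_i\, x_i + (1 - z_i)\, b, \qquad z_i \in [0, 1], \\[4pt]
&\min_{\theta}\; \mathbb{E}_{z \sim q_\theta(z \mid x)}\Big[\textstyle\sum_{i=1}^{n} \mathbb{1}[z_i \neq 0]\Big]
\quad \text{s.t.} \quad
D_{\mathrm{KL}}\big(p(y \mid x) \,\|\, p(y \mid \tilde{x})\big) \le \epsilon, \\[4pt]
&\max_{\alpha \ge 0}\; \min_{\theta}\;
\mathbb{E}_{z \sim q_\theta(z \mid x)}\Big[\textstyle\sum_{i=1}^{n} \mathbb{1}[z_i \neq 0]\Big]
+ \alpha \Big(D_{\mathrm{KL}}\big(p(y \mid x) \,\|\, p(y \mid \tilde{x})\big) - \epsilon\Big).
\end{aligned}
```

In the original DiffMask, the gates are relaxed with a Hard Concrete distribution so that the expected L0 term is differentiable, and α is updated by gradient ascent so that the penalty grows whenever the KL constraint is violated.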


Setup

We provide a conda environment for the installation of the required packages.

conda env create -f environment.yml

Project Structure

The project is organized in the following way:

.
├── code
│   ├── attributions/
│   ├── datamodules
│   │   ├── base.py
│   │   ├── image_classification.py
│   │   ├── transformations.py
│   │   ├── utils.py
│   │   └── visual_qa.py
│   ├── eval_base.py
│   ├── main.py
│   ├── models
│   │   ├── classification.py
│   │   ├── gates.py
│   │   ├── interpretation.py
│   │   └── utils.py
│   ├── train_base.py
│   └── utils
│       ├── distributions.py
│       ├── getters_setters.py
│       ├── metrics.py
│       ├── optimizer.py
│       └── plot.py
├── experiments/

Training

To train Vision DiffMask on CIFAR-10, interpreting a Vision Transformer fine-tuned on CIFAR-10, use the following command:

python code/main.py --enable_progress_bar --num_epochs 20 --base_model ViT --dataset CIFAR10 \
 --from_pretrained tanlq/vit-base-patch16-224-in21k-finetuned-cifar10

You can refer to the next section for a full list of launch options.

Launch Arguments

Vision DiffMask

When training Vision DiffMask, the following launch options can be used:

Arguments:
 --enable_progress_bar
 Whether to enable the progress bar (NOT recommended when logging to file).
 --num_epochs NUM_EPOCHS
 Number of epochs to train.
 --seed SEED Random seed for reproducibility.
 --sample_images SAMPLE_IMAGES
 Number of images to sample for the mask callback.
 --log_every_n_steps LOG_EVERY_N_STEPS
 Number of steps between logging media & checkpoints.
 --base_model {ViT} Base model architecture to train.
 --from_pretrained FROM_PRETRAINED
 The name of the pretrained HF model to load.
 --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}
 The dataset to use.
Vision DiffMask:
 --alpha ALPHA Initial value for the Lagrange multiplier.
 --lr LR Learning rate for DiffMask.
 --eps EPS KL divergence tolerance.
 --no_placeholder Do not use a placeholder for masked patches.
 --lr_placeholder LR_PLACEHOLDER
 Learning rate for the placeholder (mask) vectors.
 --lr_alpha LR_ALPHA Learning rate for the Lagrangian optimizer.
 --mul_activation MUL_ACTIVATION
 Value to multiply gate activations.
 --add_activation ADD_ACTIVATION
 Value to add to gate activations.
 --weighted_layer_distribution
 Whether to use a weighted distribution when picking a layer in DiffMask forward.
Data Modules:
 --data_dir DATA_DIR The directory where the data is stored.
 --batch_size BATCH_SIZE
 The batch size to use.
 --add_noise Use Gaussian noise augmentation.
 --add_rotation Use rotation augmentation.
 --add_blur Use blur augmentation.
 --num_workers NUM_WORKERS
 Number of workers to use for data loading.
Visual QA:
 --class_idx CLASS_IDX
 The class (index) to count.
 --grid_size GRID_SIZE
 The number of images per row in the grid.
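To make the DiffMask-specific options more concrete, here is a minimal sketch, assuming the DiffMask-style objective: masked patches are blended with a placeholder (or zeros, as a hypothetical stand-in for `--no_placeholder`), the loss adds `alpha * (KL - eps)` to a sparsity term, and `alpha` takes gradient-ascent steps of size `lr_alpha`. All function names are hypothetical, plain lists stand in for tensors, and the linear layer weights in `pick_layer` are an illustrative guess at `--weighted_layer_distribution`, not the repository's scheme.

```python
import math
import random


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def apply_mask(patches, gates, placeholder, use_placeholder=True):
    """Blend each patch embedding with a fill vector according to its gate value.

    gate = 1 keeps the patch; gate = 0 replaces it with the fill vector, which is
    the placeholder, or zeros when no placeholder is used (hypothetical behaviour).
    """
    fill = placeholder if use_placeholder else [0.0] * len(placeholder)
    return [
        [g * x + (1.0 - g) * b for x, b in zip(patch, fill)]
        for patch, g in zip(patches, gates)
    ]


def lagrangian_loss(gates, kl, alpha, eps):
    """Sum of gate values (a surrogate for the expected L0 norm) plus the KL penalty."""
    return sum(gates) + alpha * (kl - eps)


def update_alpha(alpha, kl, eps, lr_alpha):
    """One gradient-ascent step on alpha (d loss / d alpha = kl - eps), kept >= 0."""
    return max(0.0, alpha + lr_alpha * (kl - eps))


def pick_layer(num_layers, weighted=False, rng=random):
    """Sample a layer index; the linearly increasing weights are a hypothetical choice."""
    weights = list(range(1, num_layers + 1)) if weighted else [1] * num_layers
    return rng.choices(range(num_layers), weights=weights, k=1)[0]


# Usage: the masked prediction stays within the KL budget, so alpha shrinks.
p = [0.7, 0.2, 0.1]  # prediction on the original image
q = [0.6, 0.3, 0.1]  # prediction on the masked image
kl = kl_divergence(p, q)  # roughly 0.027, below eps = 0.1
alpha = update_alpha(1.0, kl, eps=0.1, lr_alpha=0.3)
masked = apply_mask([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0], placeholder=[9.0, 9.0])
# masked → [[1.0, 2.0], [9.0, 9.0]]
```

Clamping α at zero keeps the penalty one-sided: once the masked prediction is close enough to the original, the optimizer is free to close more gates.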
Training the base model

When training the base model (usually unnecessary, since pretrained models from Hugging Face are supported), the following launch options can be used:

Arguments:
 --checkpoint CHECKPOINT
 Checkpoint to resume the training from.
 --enable_progress_bar
 Whether to show a progress bar during training. NOT recommended when logging to files.
 --num_epochs NUM_EPOCHS
 Number of epochs to train.
 --seed SEED Random seed for reproducibility.
 --base_model {ViT,ConvNeXt}
 Base model architecture to train.
 --from_pretrained FROM_PRETRAINED
 The name of the pretrained HF model to fine-tune from.
 --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}
 The dataset to use.
Classification Model:
 --optimizer {AdamW,RAdam}
 The optimizer to use to train the model.
 --weight_decay WEIGHT_DECAY
 The optimizer's weight decay.
 --lr LR The initial learning rate for the model.
Data Modules:
 --data_dir DATA_DIR The directory where the data is stored.
 --batch_size BATCH_SIZE
 The batch size to use.
 --add_noise Use Gaussian noise augmentation.
 --add_rotation Use rotation augmentation.
 --add_blur Use blur augmentation.
 --num_workers NUM_WORKERS
 Number of workers to use for data loading.
Visual QA:
 --class_idx CLASS_IDX
 The class (index) to count.
 --grid_size GRID_SIZE
 The number of images per row in the grid.
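The `--add_noise` and `--add_blur` flags can be pictured as switches that append transforms to an augmentation pipeline. The sketch below works on a grayscale image stored as a list of rows with values in [0, 1]; the noise standard deviation is a hypothetical default, a 3x3 box blur is swapped in for Gaussian blur, and rotation is omitted. In a PyTorch pipeline these would more typically be `torchvision.transforms.GaussianBlur` and `torchvision.transforms.RandomRotation`; how this repository implements them is not shown here.

```python
import random


def add_gaussian_noise(image, std=0.05, rng=random):
    """Add zero-mean Gaussian noise to each pixel and clip to [0, 1] (std is hypothetical)."""
    return [[min(1.0, max(0.0, px + rng.gauss(0.0, std))) for px in row] for row in image]


def box_blur(image):
    """3x3 mean filter (stand-in for Gaussian blur); borders average only in-image neighbours."""
    h, w = len(image), len(image[0])
    out = []
    for i in range(h):
        row = []
        for j in range(w):
            window = [
                image[y][x]
                for y in range(max(0, i - 1), min(h, i + 2))
                for x in range(max(0, j - 1), min(w, j + 2))
            ]
            row.append(sum(window) / len(window))
        out.append(row)
    return out


def build_augmentations(add_noise=False, add_blur=False, rng=random):
    """Map boolean flags to an ordered list of transforms."""
    transforms = []
    if add_noise:
        transforms.append(lambda img: add_gaussian_noise(img, rng=rng))
    if add_blur:
        transforms.append(box_blur)
    return transforms


def augment(image, transforms):
    """Apply the transforms one after another."""
    for transform in transforms:
        image = transform(image)
    return image


# Usage: noise followed by blur on a tiny 2x2 image, with a seeded RNG.
img = [[0.0, 1.0], [1.0, 0.0]]
out = augment(img, build_augmentations(add_noise=True, add_blur=True, rng=random.Random(0)))
```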
Evaluating the base model

When evaluating the base model, the following launch options can be used:

Arguments:
 --checkpoint CHECKPOINT
 Checkpoint of the model to evaluate.
 --enable_progress_bar
 Whether to show a progress bar during evaluation. NOT recommended when logging to files.
 --seed SEED Random seed for reproducibility.
 --base_model {ViT,ConvNeXt}
 Base model architecture to evaluate.
 --from_pretrained FROM_PRETRAINED
 The name of the pretrained HF model to load.
 --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}
 The dataset to use.
Data Modules:
 --data_dir DATA_DIR The directory where the data is stored.
 --batch_size BATCH_SIZE
 The batch size to use.
 --add_noise Use Gaussian noise augmentation.
 --add_rotation Use rotation augmentation.
 --add_blur Use blur augmentation.
 --num_workers NUM_WORKERS
 Number of workers to use for data loading.
Visual QA:
 --class_idx CLASS_IDX
 The class (index) to count.
 --grid_size GRID_SIZE
 The number of images per row in the grid.
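The Visual QA options suggest a counting task: a grid of images with `grid_size` images per row, where the target is how many of them belong to class `class_idx`. The sketch below is a hypothetical reading of that setup (a square grid, distinct images, row-major tiling); the function names are ours and the actual `CIFAR10_QA` construction in this repository may differ. Images are lists of rows so the example runs without extra dependencies.

```python
import random


def make_counting_sample(labels, class_idx, grid_size, rng=random):
    """Pick grid_size**2 distinct dataset indices and count those labelled class_idx."""
    picks = rng.sample(range(len(labels)), grid_size * grid_size)
    grid = [picks[r * grid_size:(r + 1) * grid_size] for r in range(grid_size)]
    count = sum(1 for i in picks if labels[i] == class_idx)
    return grid, count


def tile(images, grid_size):
    """Stitch equally sized 2-D images (lists of rows) into one image, row-major."""
    height = len(images[0])
    canvas = []
    for r in range(grid_size):
        row_images = images[r * grid_size:(r + 1) * grid_size]
        for y in range(height):
            canvas.append([px for img in row_images for px in img[y]])
    return canvas


# Usage: a 2x2 grid drawn from ten labelled images, counting class 3.
labels = [3, 5, 3, 1, 3, 0, 2, 3, 9, 3]
grid, count = make_counting_sample(labels, class_idx=3, grid_size=2, rng=random.Random(0))
picture = tile([[[0]], [[1]], [[2]], [[3]]], grid_size=2)  # → [[0, 1], [2, 3]]
```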

License

This project is licensed under the MIT license.

Acknowledgements

Vision DiffMask is an adaptation of DiffMask to the vision domain. Parts of the code are heavily inspired by its original PyTorch implementation.

Citation

If you use this code or find our work otherwise useful, please consider citing our paper:

@inproceedings{nalmpantis2023vision,
 title={VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking},
 author={Nalmpantis, Angelos and Panagiotopoulos, Apostolos and Gkountouras, John and Papakostas, Konstantinos and Aziz, Wilker},
 booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
 pages={3755--3762},
 year={2023}
}
