This repository provides the official PyTorch implementation of our NeurIPS 2022 paper:
Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models
Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao
For more details, please check out our project page and paper.
This repository contains the implementation of TPT for image classification with a pre-trained CLIP model. We consider three different initializations for test-time prompt tuning:
- Using a hand-crafted prompt as initialization (e.g., "a photo of a ___").
- Using a learned soft prompt (CoOp) as initialization.
- Using the output of a trained conditional prompt learner (CoCoOp) as initialization.
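Whichever initialization is used, the prompt is then tuned on each test sample individually. As we understand the method from the paper, the per-sample objective is the entropy of the prediction averaged over the most confident augmented views of the test image (confidence selection). The snippet below is a minimal, dependency-free sketch of that loss value only: it is not code from this repository, `softmax`, `entropy`, and `tpt_loss` are hypothetical helpers, and the `rho=0.1` default is an assumption. In the actual method the loss is backpropagated to the prompt with PyTorch, which this sketch does not do.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def tpt_loss(view_logits: Sequence[Sequence[float]], rho: float = 0.1) -> float:
    """Hypothetical helper: entropy of the prediction averaged over confident views.

    view_logits holds one logit vector per augmented view of a single test image.
    rho (assumed default) is the fraction of lowest-entropy views that is kept.
    """
    probs = [softmax(v) for v in view_logits]
    # Confidence selection: keep the rho fraction of views with the lowest entropy.
    k = max(1, int(rho * len(probs)))
    confident = sorted(probs, key=entropy)[:k]
    # Average the selected class distributions, then take the entropy of the average.
    num_classes = len(confident[0])
    avg = [sum(p[c] for p in confident) / k for c in range(num_classes)]
    return entropy(avg)


# Two confident views and eight near-uniform views of the same (toy) test image.
views = [[4.0, 0.0, 0.0]] * 2 + [[0.0, 0.1, 0.0]] * 8
print(tpt_loss(views, rho=0.2))  # only the two confident views are averaged
```

Filtering out the near-uniform views keeps noisy augmentations (for example, crops that miss the object) from dominating the averaged prediction that the prompt is tuned against.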
This implementation is for the single-GPU configuration.
To evaluate on ImageNet, ImageNet-V2, and ImageNet-Sketch (each of which has 1,000 classes), you will need a GPU with more than 16 GB of memory (16 GB is not enough). This codebase has been tested on a GPU with 24 GB of memory. For the other datasets (which have at most a few hundred classes), a GPU with 16 GB of memory works fine.
The code has been tested with PyTorch 1.7.1.
We suggest downloading all datasets to a single root directory (${data_root}) and renaming each dataset's directory as given in ${ID_to_DIRNAME} in ./data/datautils.py. This allows you to evaluate multiple datasets in the same run.
If this is not feasible, you can evaluate the datasets separately and change ${data_root} in the bash script accordingly.
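With the single-root layout, a quick sanity check before launching a run is to confirm that every dataset you plan to evaluate has its directory under ${data_root}. The sketch below assumes a mapping shaped like ${ID_to_DIRNAME} (set_id to directory name); the entries in `EXAMPLE_ID_TO_DIRNAME` are placeholders rather than the actual values from ./data/datautils.py, and `missing_dataset_dirs` is a hypothetical helper that is not part of this repository.

```python
from pathlib import Path
from typing import Dict, Iterable, List

# Placeholder set_id -> directory-name entries; the real mapping is
# ID_to_DIRNAME in ./data/datautils.py.
EXAMPLE_ID_TO_DIRNAME: Dict[str, str] = {
    "I": "ImageNet",
    "A": "ImageNet-A",
}


def missing_dataset_dirs(data_root: str, set_ids: Iterable[str],
                         id_to_dirname: Dict[str, str]) -> List[str]:
    """Hypothetical helper: set_ids whose expected directory is absent under data_root.

    Raises KeyError if a set_id has no entry in id_to_dirname.
    """
    root = Path(data_root)
    return [sid for sid in set_ids if not (root / id_to_dirname[sid]).is_dir()]


missing = missing_dataset_dirs("/path/to/data_root", ["I", "A"], EXAMPLE_ID_TO_DIRNAME)
if missing:
    print(f"missing dataset directories for set_id(s): {missing}")
```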
For out-of-distribution generalization, we consider 5 datasets:
- ImageNet
- ImageNet-A
- ImageNet-V2
- ImageNet-R
- ImageNet-Sketch
For cross-dataset generalization, we consider 10 datasets:
For cross-dataset generalization, we adopt the same train/val/test splits as CoOp. Please refer to this page for the download links of split_zhou_${dataset_name}.json, and put the JSON files under ./data/data_splits/.
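To our understanding, CoOp's own data loaders read each split_zhou_${dataset_name}.json as a JSON object with "train", "val", and "test" lists, where each entry is a [relative_image_path, label, classname] triple. Assuming the downloaded files follow that layout, a minimal loader could look like the sketch below; `parse_split` and `read_split` are hypothetical helpers, not functions from this repository, and the paths in the usage lines are placeholders.

```python
import json
import os
from typing import Dict, List, Tuple

Sample = Tuple[str, int, str]  # (image path, integer label, class name)


def parse_split(raw: Dict[str, list], image_dir: str) -> Dict[str, List[Sample]]:
    """Turn an already-loaded split dict into (path, label, classname) tuples.

    Assumes each of "train"/"val"/"test" maps to a list of
    [relative_image_path, label, classname] entries; a missing key yields [].
    """
    return {
        name: [(os.path.join(image_dir, impath), int(label), classname)
               for impath, label, classname in raw.get(name, [])]
        for name in ("train", "val", "test")
    }


def read_split(split_file: str, image_dir: str) -> Dict[str, List[Sample]]:
    """Load a split_zhou_*.json file from disk and parse it."""
    with open(split_file, "r", encoding="utf-8") as f:
        return parse_split(json.load(f), image_dir)


# Usage (placeholder paths):
# splits = read_split("./data/data_splits/split_zhou_<dataset_name>.json", "/path/to/images")
# test_samples = splits["test"]
```

Keeping the parsing separate from file I/O makes the format assumption easy to check against a real split file before wiring it into a data loader.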
We provide three bash scripts under ./scripts. You can modify the paths and other args in the scripts.
An example of running TPT with CoOp initialization on the out-of-distribution datasets:
bash ./scripts/test_coop.sh I/A/V/R/K
The command-line argument ${testsets} can list multiple test datasets separated by "/" (all of which must be stored under the same root directory ${data_root}).
Note that, for simplicity, we use set_id to denote different datasets. A complete list of set_id values can be found in ${ID_to_DIRNAME} in ./data/datautils.py.
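The "/"-separated ${testsets} argument can be split and validated before any data is loaded. The sketch below is a hypothetical helper, not the repository's own argument handling; `EXAMPLE_SET_IDS` reflects our reading of the I/A/V/R/K example (ImageNet, ImageNet-A, ImageNet-V2, ImageNet-R, ImageNet-Sketch), and the authoritative list of set_id values remains ${ID_to_DIRNAME} in ./data/datautils.py.

```python
from typing import Dict, List

# Assumed meaning of the set_ids in the example command; the authoritative
# mapping (to directory names) is ID_to_DIRNAME in ./data/datautils.py.
EXAMPLE_SET_IDS: Dict[str, str] = {
    "I": "ImageNet",
    "A": "ImageNet-A",
    "V": "ImageNet-V2",
    "R": "ImageNet-R",
    "K": "ImageNet-Sketch",
}


def parse_testsets(testsets: str, known_ids: Dict[str, str]) -> List[str]:
    """Hypothetical helper: split e.g. "I/A/V/R/K" into set_ids and reject unknown ones."""
    set_ids = [s for s in testsets.split("/") if s]  # ignore empty pieces such as "I//A"
    unknown = [s for s in set_ids if s not in known_ids]
    if unknown:
        raise ValueError(f"unknown set_id(s): {unknown}")
    return set_ids


for sid in parse_testsets("I/A/V/R/K", EXAMPLE_SET_IDS):
    print(sid, "->", EXAMPLE_SET_IDS[sid])
```

Validating up front means a typo in ${testsets} surfaces as a clear error instead of a failed lookup partway through a long evaluation.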
Top-1 accuracy (%) on ImageNet and its out-of-distribution variants. "Average" is the mean over all five datasets; "OOD Average" is the mean over the four out-of-distribution datasets (IN-A, IN-V2, IN-R, IN-Sketch).

| Method | ImageNet (IN) | IN-A | IN-V2 | IN-R | IN-Sketch | Average | OOD Average |
|---|---|---|---|---|---|---|---|
| CLIP-RN50 | 58.16 | 21.83 | 51.41 | 56.15 | 33.37 | 44.18 | 40.69 |
| Ensembled prompt | 59.81 | 23.24 | 52.91 | 60.72 | 35.48 | 46.43 | 43.09 |
| CoOp | 63.33 | 23.06 | 55.40 | 56.60 | 34.67 | 46.61 | 42.43 |
| CoCoOp | 62.81 | 23.32 | 55.72 | 57.74 | 34.48 | 46.81 | 42.82 |
| TPT (ours) | 60.74 | 26.67 | 54.70 | 59.11 | 35.09 | 47.26 | 43.89 |
| TPT + CoOp | 64.73 | 30.32 | 57.83 | 58.99 | 35.86 | 49.55 | 45.75 |
| TPT + CoCoOp | 62.93 | 27.40 | 56.60 | 59.88 | 35.43 | 48.45 | 44.83 |
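The two rightmost columns can be recomputed from the per-dataset numbers: every row of the table is consistent with "Average" being the mean over all five datasets and "OOD Average" the mean over the four out-of-distribution variants. A small check using the TPT (ours) row; `table_averages` is a hypothetical helper, not part of this repository.

```python
from statistics import mean
from typing import Dict, Tuple

OOD_COLUMNS = ("IN-A", "IN-V2", "IN-R", "IN-Sketch")


def table_averages(row: Dict[str, float]) -> Tuple[float, float]:
    """Hypothetical helper: (Average over all five sets, OOD Average over the four OOD sets)."""
    ood_values = [row[c] for c in OOD_COLUMNS]
    overall = mean([row["ImageNet"], *ood_values])
    ood = mean(ood_values)
    return overall, ood


# Per-dataset accuracies of the "TPT (ours)" row.
tpt = {"ImageNet": 60.74, "IN-A": 26.67, "IN-V2": 54.70, "IN-R": 59.11, "IN-Sketch": 35.09}
overall, ood = table_averages(tpt)
print(f"{overall:.2f} {ood:.2f}")  # prints "47.26 43.89", matching the table
```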
[Figure: Cross-dataset improvement normalized by the zero-shot baseline performance.]
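The caption does not spell out the normalization. One natural reading, which is our assumption rather than something stated in the code, is relative improvement: (accuracy - zero-shot accuracy) / zero-shot accuracy. A minimal sketch follows; `normalized_improvement` is a hypothetical helper, and the numbers are the ImageNet entries for TPT (ours) and CLIP-RN50 from the results table, used only to illustrate the arithmetic, not values from the figure.

```python
def normalized_improvement(acc: float, zero_shot_acc: float) -> float:
    """Hypothetical helper: gain relative to the zero-shot baseline, (acc - base) / base."""
    if zero_shot_acc <= 0:
        raise ValueError("zero-shot accuracy must be positive")
    return (acc - zero_shot_acc) / zero_shot_acc


# Illustration only: TPT (ours) vs. CLIP-RN50 on ImageNet, from the results table.
print(f"{normalized_improvement(60.74, 58.16):.2%}")  # prints "4.44%"
```

Normalizing by the baseline lets gains on datasets with very different zero-shot accuracies be compared on one color scale.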
If you find our code useful or our work relevant, please consider citing:
@inproceedings{shu2022tpt,
  author    = {Shu, Manli and Nie, Weili and Huang, De-An and Yu, Zhiding and Goldstein, Tom and Anandkumar, Anima and Xiao, Chaowei},
  title     = {Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models},
  booktitle = {NeurIPS},
  year      = {2022},
}
We thank the authors of CoOp/CoCoOp for their open-source implementation and instructions on data preparation.