This repository is the official implementation of the paper *Trained Models Tell Us How to Make Them Robust to Group Shifts without Group Annotation*. It contains code for running experiments on several datasets to mitigate spurious correlations in deep learning models.
Our code requires Python 3.10 or higher. Please use either `requirements.txt` with pip or `env.yml` with conda to create an environment and install the dependencies.
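For reference, a minimal setup could look like the following (either option is sufficient; the conda environment name is taken from `env.yml`):

```bash
# Option 1: pip (ideally inside a fresh virtual environment)
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

# Option 2: conda
conda env create -f env.yml
```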
The following datasets are supported: Waterbirds, CelebA, MultiNLI, Domino, Colored MNIST (CMNIST), CivilComments, and UrbanCars.
Please follow the instructions in the Data Access section to set up the datasets.
- ERM Training:
  - Waterbirds example:

```bash
python main.py --root_dir ./ --experiment ERM --dataset waterbirds --dataset_path /path/to/waterbird_complete95_forest2water2 --optimizer SGD -lr 1e-3 --step_size 100 --weight_decay 1e-4 --gamma 0.5 --epochs 300 --pretrained_path imagenet -b 128
```
  - CivilComments and MultiNLI: the ERM models were trained using the code provided by this repo.
- EVaLS-GL:
  - CelebA example:

```bash
python main.py --dataset celeba --dataset_path /path/to/data/celeba --experiment loss --sample_size 5 -b 32 -lr 0.0005 --pretrained_path /path/to/resnet50.model --gamma 0.1 --weight_decay 0 --l1 0 --epochs 100 --optimizer adam --step_size 85 --seed 0 --feature_only True
```
- EVaLS:

Prior to running EVaLS, you should run `EIIL.py`. Here is an example script:

```bash
python EIIL.py --dataset urbancars --dataset_path path/to/urbancars/noaug_features_seed0 --learning_rate 0.01 --num_steps 20000 --batch_size 128 --feature_only True --save_path path/to/save/urbancars/seed0/ --pretrained_path path/to/ckpts/urbancars/erm_seed0/ckpt.pth
```

This will create the new validation environments in the `validation_path`.
  - UrbanCars example:

```bash
python3 main.py --dataset urbancars --dataset_path /path/to/data/urbancars/noaug_features_seed0 --experiment loss --sample_size 10 -b 32 -lr 0.0005 --pretrained_path /path/to/ckpts/urbancars/erm_seed0/ckpt.pth --gamma 0.1 --weight_decay 0 --l1 0 --epochs 100 --optimizer adam --step_size 85 --seed 0 --feature_only True --validation_path /path/to/validation_groups/urbancars/seed1 --for_free True
```
Note: If the `--feature_only` flag is used, you should provide the pre-computed features of the specified dataset, which can be saved using the `save_features.py` script in the repository. If the flag is not specified, the raw image or text files of the dataset should be provided. Here is an example script:

```bash
python3 save_features.py --dataset civilcomments --dataset_path path/to/data/civilcomments --save_path path/to/save/civilcomments/seed1/ --pretrained_path path/to/civilcomments/erm_seed1 --batch_size 64
```
Follow the instructions in the DFR repo to prepare the Waterbirds and CelebA datasets.
Our code expects the following files/folders in the `[root_dir]/celebA` directory:

```
data/celeba_metadata.csv
data/img_align_celeba/
```

You can download these dataset files from this Kaggle link. Make sure to move `metadata/celeba_metadata.csv`, which contains the last-layer split, to the `data` directory.
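As a sketch, that step could look like the following (paths are illustrative; adjust them to where you cloned the repo and placed the data):

```bash
mv metadata/celeba_metadata.csv /path/to/root_dir/celebA/data/celeba_metadata.csv
```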
Our code expects the following files/folders in the `[root_dir]/` directory:

```
data/waterbird_complete95_forest2water2/
```

You can download a tarball of this dataset here. Make sure to move `metadata/wb_metadata.csv`, which contains the last-layer split, to the dataset directory and rename it to `metadata.csv`.
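A sketch of the move-and-rename step (paths are illustrative):

```bash
mv metadata/wb_metadata.csv /path/to/root_dir/data/waterbird_complete95_forest2water2/metadata.csv
```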
For the CivilComments dataset, we have altered the split column. The version with the last-layer split can be downloaded from this link.
To run experiments on the MultiNLI dataset, please manually download and unzip the dataset from this link. Then copy `utils_glue.py` to the root directory of the dataset, and add the metadata with the last-layer split from `metadata/multinli/metadata_random.csv` to the dataset directory.
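A sketch of those two steps, with `/path/to/multinli` standing in for the unzipped dataset root:

```bash
cp utils_glue.py /path/to/multinli/
cp metadata/multinli/metadata_random.csv /path/to/multinli/
```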
You can create and access our modified dominoes dataset in `notebooks/dominoes.ipynb`.
For the UrbanCars dataset, please refer to the Whac-A-Mole repo. Since creating the whole dataset is time-consuming and somewhat challenging, we have uploaded the UrbanCars images here for ease of access.
To run an experiment, use the `main.py` script with the appropriate arguments:
```
python main.py [--root_dir ROOT_DIR] [--learning_rate LEARNING_RATE] [--optimizer {adam,adamW,SGD}]
               --experiment {ERM,DFR,loss,cluster,entropy,gradcam}
               --dataset {waterbirds,celeba,multinli,domino,cmnist,civilcomments,metashift,urbancars}
               --dataset_path DATASET_PATH [--comments COMMENTS] [--output_path OUTPUT_PATH] [--bert_ckpt BERT_CKPT]
               [--sample_size SAMPLE_SIZE] [--weight_decay WEIGHT_DECAY] [--l1 L1] [--step_size STEP_SIZE] [--gamma GAMMA]
               [--epochs EPOCHS] [--model {ResNet,BERT}] [--pretrained_path PRETRAINED_PATH] [--batch_size BATCH_SIZE]
               [--num_workers NUM_WORKERS] [--test_only TEST_ONLY] [--log LOG] [--for_free FOR_FREE] [--seed SEED]
               [--random_grouping RANDOM_GROUPING] [--feature_only FEATURE_ONLY] [--num_val NUM_VAL]
               [--fine_tune FINE_TUNE] [--early_stop_val EARLY_STOP_VAL] [--validation_path VALIDATION_PATH]
               [--saved_val SAVED_VAL]
```
- `--root_dir`: Path to the root directory of the project (default: `None`).
- `--learning_rate`, `-lr`: Learning rate for the optimizer (default: `0.001`).
- `--optimizer`: Type of optimizer (choices: `adam`, `adamW`, `SGD`; default: `adam`).
- `--experiment`: Type of experiment (choices: `ERM`, `DFR`, `loss`, `cluster`, `entropy`, `gradcam`; required). `loss` is equivalent to EVaLS.
- `--dataset`: Name of the dataset (choices: `waterbirds`, `celeba`, `multinli`, `domino`, `cmnist`, `civilcomments`, `metashift`, `urbancars`; required).
- `--dataset_path`: Path of the dataset (default: `./waterbird_complete_forest2water2`).
- `--comments`: Additional comments to be included in the log name (default: `''`).
- `--output_path`: Path for logs and checkpoints (default: `/home/logs/`).
- `--bert_ckpt`: Weights of pre-trained BERT for tokenization (default: `bert-base-uncased`).
- `--sample_size`: Sample size of each group in the experiment (default: `64`).
- `--weight_decay`: Weight decay coefficient for L2 regularization (default: `0`).
- `--l1`: Coefficient for L1 regularization (default: `0`).
- `--step_size`: Step size for the LR scheduler (default: `10`).
- `--gamma`: Gamma for the LR scheduler (default: `0.1`).
- `--epochs`: Number of epochs (default: `30`).
- `--model`: Name of the model to use (choices: `ResNet`, `BERT`; default: `resnet`).
- `--pretrained_path`: Path of the pre-trained model file (required for some experiments).
- `--batch_size`, `-b`: Batch size for the last-layer retraining (default: `128`).
- `--num_workers`: Number of CPU cores to use (default: `8`).
- `--test_only`: Only test the specified model on the specified dataset and report worst-group accuracy (WGA) and average accuracy (default: `False`).
- `--log`: Whether to log the experiment on wandb (default: `True`).
- `--for_free`: Choose the best model based on group-inferred validation data, not the ground-truth group annotations (default: `False`).
- `--seed`: Random seed (default: `1`).
- `--random_grouping`: Randomly group validation data (default: `False`).
- `--feature_only`: Load pre-computed features instead of raw data (default: `False`).
- `--num_val`: Number of validation sets (default: `1`).
- `--fine_tune`: Whether to fine-tune the classifier (default: `False`).
- `--early_stop_val`: Use early-stop models for validation grouping (default: `False`).
- `--validation_path`: Path to the validation groupings inferred from EIIL (default: `None`).
- `--saved_val`: Use a saved validation set (default: `False`).
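As an illustration, a hypothetical evaluation-only run might look like the following (placeholder paths; boolean flags passed as in the examples above):

```bash
python main.py --root_dir ./ --experiment ERM --dataset waterbirds --dataset_path /path/to/waterbird_complete95_forest2water2 --pretrained_path /path/to/ckpt.pth --test_only True
```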
This project welcomes contributions. To contribute, please follow these steps:
- Fork the repository
- Create a new branch
- Make your changes and commit them
- Push to the branch
- Create a new Pull Request
This project is licensed under the MIT License.
We would like to thank the authors of the following papers and repositories for their valuable contributions:
If you use this code in your research, please cite our paper:
```bibtex
@article{my-paper-title,
  title={My Paper Title},
  author={Author Names},
  journal={Journal Name},
  year={2023}
}
```