google-deepmind/distribution_shift_framework
Description: This repository contains the code of the distribution shift framework presented in A Fine-Grained Analysis on Distribution Shift (Wiles et al., 2022).
Language: Python
License: Apache-2.0
Stars: 88
Forks: 9
Open issues: 15
Created: 2022-03-17T08:43:06Z
Pushed: 2026-05-19T23:52:54Z
Default branch: master
Fork: no
Archived: no
README:
Distribution Shift Framework
This repository contains the code of the distribution shift framework presented in A Fine-Grained Analysis on Distribution Shift (Wiles et al., 2022).
Contents
The framework lets you train models with different training methods on datasets that undergo specific kinds of distribution shift.
Training Methods
Currently the following training methods are supported (by setting the algorithm [config option](#config-options)):
- Empirical Risk Minimization (ERM, Vapnik, 1992)
- Invariant Risk Minimization (IRM, Arjovsky et al., 2019)
- Deep Correlation Alignment (Deep CORAL, Sun & Saenko, 2016)
- Domain-Adversarial Training of Neural Networks (DANN, Ganin et al., 2016)
- Style-Agnostic Networks (SagNet, Nam et al., 2021)
- Batch Normalization Adaptation (BN-Adapt, Schneider et al., 2020)
- Just Train Twice (JTT, Liu et al., 2021)
- Inter-domain Mixup (MixUp, Gulrajani & Lopez-Paz, 2021)
Model Architectures
The model [config option](#config-options) can be set to one of the following architectures:
- ResNet18, ResNet50, ResNet101 (He et al., 2016)
- MLP (Vapnik, 1992)
Datasets
You can train on the following datasets (by setting the dataset_name [config option](#config-options)):
- dSprites (Matthey et al., 2017)
- SmallNorb (LeCun et al., 2004)
- Shapes3D (Burgess & Kim, 2018)
Each dataset has a task (e.g. shape prediction on dSprites, set with the label [config option](#config-options)) and a set of properties (e.g. the colour of the shape in dSprites, set with the property_label [config option](#config-options)).
Distribution Shift Scenarios
You can evaluate your model on different conditions by varying the distribution of labels and properties in the configs. For each part of the distribution, you then assign a probability of sampling from that part of the distribution.
- Unseen data shift (ood): Some parts of the distribution of the property are unseen at training time (e.g. certain colours may be unseen in dSprites).
- Spurious correlation (correlated): Some property is correlated with the label at training time but not at test time (e.g. all circles are red in training).
- Low data drift (lowdata): Certain combinations of label and property are seen at a lower rate during training, while they are uniformly distributed at test time.
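As a rough illustration of the three scenarios above, the sketch below builds a joint probability table over (label, property) pairs for each case and samples from it. It is not the framework's own implementation: the property values, the helper names, and the choice of which combinations are rare or unseen are all assumptions made for the example.

```python
import numpy as np

LABELS = ["square", "ellipse", "heart"]  # dSprites shapes (the task label)
PROPERTIES = ["red", "green", "blue"]    # hypothetical colour property values


def uniform_probs():
    """Test-time distribution: every (label, property) pair is equally likely."""
    shape = (len(LABELS), len(PROPERTIES))
    return np.full(shape, 1.0 / (shape[0] * shape[1]))


def ood_probs(unseen_property=2):
    """Unseen data shift: one property value never appears during training."""
    probs = np.ones((len(LABELS), len(PROPERTIES)))
    probs[:, unseen_property] = 0.0
    return probs / probs.sum()


def correlated_probs():
    """Spurious correlation: each label is paired with exactly one property."""
    probs = np.eye(len(LABELS), len(PROPERTIES))
    return probs / probs.sum()


def lowdata_probs(rare_weight=0.1):
    """Low data drift: off-diagonal pairs are sampled at a reduced rate."""
    probs = np.full((len(LABELS), len(PROPERTIES)), rare_weight)
    np.fill_diagonal(probs, 1.0)
    return probs / probs.sum()


def sample_pairs(probs, n, seed=0):
    """Draw n (label_index, property_index) pairs from a joint table."""
    rng = np.random.default_rng(seed)
    flat = rng.choice(probs.size, size=n, p=probs.ravel())
    return np.stack(np.unravel_index(flat, probs.shape), axis=1)
```

In each case the training table differs from `uniform_probs()`, which plays the role of the test distribution; the gap between the two tables is the shift being studied.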
Additionally you can modify these scenarios with two conditions:
- Label noise (noise): A certain percentage of the training labels are corrupted.
- Fixed dataset size (fixeddata): We reduce the total training dataset size to a fixed amount.
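The two modifiers can be pictured with a short sketch. The function names and the exact corruption rule (swapping in a different class chosen at random) are assumptions for illustration; the framework may corrupt labels or subsample differently.

```python
import numpy as np


def add_label_noise(labels, noise_fraction, num_classes, seed=0):
    """Replace a fraction of labels with a different, randomly chosen class."""
    rng = np.random.default_rng(seed)
    labels = np.array(labels, copy=True)  # leave the caller's array untouched
    n_corrupt = int(round(noise_fraction * len(labels)))
    idx = rng.choice(len(labels), size=n_corrupt, replace=False)
    # An offset in [1, num_classes - 1] guarantees the label actually changes.
    offsets = rng.integers(1, num_classes, size=n_corrupt)
    labels[idx] = (labels[idx] + offsets) % num_classes
    return labels


def fix_dataset_size(num_examples, target_size, seed=0):
    """Return sorted indices of a random subset of at most target_size examples."""
    rng = np.random.default_rng(seed)
    size = min(target_size, num_examples)
    return np.sort(rng.choice(num_examples, size=size, replace=False))
```

Both helpers act only on the training split; the test distribution is left unchanged.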
These scenarios can be set through the test_case [config option](#config-options) with the keywords in parentheses and an optional modifier separated by a full stop, e.g. lowdata.noise for low data drift with added label noise.
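The keyword format can be made concrete with a small parser. This is a minimal sketch of the naming convention described above, not the framework's own parsing code; `parse_test_case` is a hypothetical helper.

```python
SCENARIOS = {"ood", "correlated", "lowdata"}
MODIFIERS = {"noise", "fixeddata"}


def parse_test_case(test_case):
    """Split a test_case string into (scenario, modifier or None)."""
    parts = test_case.split(".")
    if len(parts) > 2:
        raise ValueError(f"Expected at most one modifier, got: {test_case!r}")
    scenario = parts[0]
    modifier = parts[1] if len(parts) == 2 else None
    if scenario not in SCENARIOS:
        raise ValueError(f"Unknown scenario: {scenario!r}")
    if modifier is not None and modifier not in MODIFIERS:
        raise ValueError(f"Unknown modifier: {modifier!r}")
    return scenario, modifier
```

For example, `parse_test_case("lowdata.noise")` yields the scenario `lowdata` with the modifier `noise`, while `parse_test_case("ood")` yields `ood` with no modifier.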
Future Additions
We plan to add additional methods, models and datasets from the paper as well as the raw results from all the experiments.
Usage Instructions
Installing
The following has been tested using Python 3.9.9.
For GPU support with JAX, edit requirements.txt before running run.sh (e.g., use jaxlib==0.1.67+cuda111). See JAX's installation instructions for more details.
Execute run.sh to create and activate a virtualenv, install all necessary dependencies and run a test program to ensure that you can import all the modules.
    # Run from the parent directory.
    sh distribution_shift_framework/run.sh
Running the Code
To train a model, use this virtualenv:
    source /tmp/distribution_shift_framework/bin/activate
and then run
    python3 -m distribution_shift_framework.classification.experiment \
      --jaxline_mode=train \
      --config=distribution_shift_framework/classification/config.py
For evaluation run
    python3 -m distribution_shift_framework.classification.experiment \
      --jaxline_mode=eval \
      --config=distribution_shift_framework/classification/config.py
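If you prefer to launch training and evaluation from a Python script, the two commands above can be assembled programmatically. The sketch below uses only the module path and flags shown in this README; `build_command` is a hypothetical helper, and it deliberately accepts only the two modes documented here.

```python
import sys

EXPERIMENT_MODULE = "distribution_shift_framework.classification.experiment"
CONFIG_PATH = "distribution_shift_framework/classification/config.py"


def build_command(mode, python=sys.executable):
    """Build the argument list for a train or eval run."""
    if mode not in ("train", "eval"):
        raise ValueError(f"Unsupported mode: {mode!r}")
    return [
        python,
        "-m",
        EXPERIMENT_MODULE,
        f"--jaxline_mode={mode}",
        f"--config={CONFIG_PATH}",
    ]
```

With the virtualenv activated and the parent directory as the working directory, the result can be passed to `subprocess.run(build_command("train"), check=True)`.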
Config Options {#config-options}
Common changes can be made through an options string following the config file. The following options are available:
- algorithm: Which training method to use.
- model: The model architecture to evaluate.
- dataset_name: The name of the dataset.
- test_case: Which of the distribution shift scenarios to set up.
- label: The label we're predicting.
- property_label: Which property is treated as in or out of distribution (for the ood test_case), is correlated with the label (for the correlated test_case), or is treated as having a low data region (for the lowdata test_case).
- number_of_seeds: How many seeds to sweep over.
- batch_size: Batch size used for training and evaluation.
- training_steps: How many steps to train for.
- pretrained_checkpoint: Path to a checkpoint for a pretrained model.
- overwrite_image_size: Height and width to resize the images to. 0 means no resizing.
- eval_specific_ckpt: Path to a checkpoint for a one time evaluation.
- wids: Which wids of the checkpoint to look at.
- sweep_index: Which experiment from the sweep to run.
- use_fake_data: Whether…