amazon-science/TransitionFlowMatching
Description: Official implementation of "Demystifying Transition Matching: When and Why It Can Beat Flow Matching" (AISTATS 2026). Code for image and video generation using Transition Matching.
Language: Python
License: Apache-2.0
Stars: 12
Forks: 0
Open issues: 2
Created: 2026-04-10T19:19:08Z
Pushed: 2026-04-14T01:38:03Z
Default branch: main
Fork: no
Archived: no
README:
Demystifying Transition Matching
This repository contains the official implementation of [Demystifying Transition Matching: When and Why It Can Beat Flow Matching](https://arxiv.org/abs/2510.17991), accepted to the Twenty-Ninth Annual Conference on Artificial Intelligence and Statistics (AISTATS), 2026.
Flow Matching (FM) underpins many state-of-the-art generative models, yet Transition Matching (TM) can achieve higher sample quality with fewer steps. We answer *when* and *why*: for unimodal Gaussian targets, TM attains strictly lower KL divergence than FM at any finite step count, because its stochastic latent updates preserve target covariance that deterministic FM underestimates. For Gaussian mixtures with well-separated modes, the distribution is approximately locally unimodal within each component, and TM retains this advantage — explaining its strong performance in multimodal settings.
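The covariance argument can be checked numerically in the simplest setting. The sketch below is a toy illustration under our own assumptions, not the paper's analysis or code. It uses a 1-D source N(0, 1), a 1-D target N(0, σ²), the linear path x_t = (1 - t) x_0 + t x_1 with independent endpoints, and the exact (not learned) velocity and transition kernel. Deterministic Euler integration of the FM ODE shrinks the terminal variance at small step counts; a single step collapses every sample to the mean. Sampling each step from the exact Gaussian conditional p(x_{t+h} | x_t), an idealized stand-in for a stochastic transition update, reaches σ² at any step count. The function names and the KL direction KL(model ‖ target) are illustrative choices.

```python
import math


def _var(t: float, sigma: float) -> float:
    """Var[x_t] for x_t = (1 - t) * x0 + t * x1, x0 ~ N(0, 1), x1 ~ N(0, sigma**2)."""
    return (1 - t) ** 2 + (t * sigma) ** 2


def fm_euler_variance(sigma: float, num_steps: int) -> float:
    """Terminal variance of deterministic Euler sampling with the exact FM velocity (toy 1-D setup)."""
    h = 1.0 / num_steps
    scale = 1.0  # the velocity is linear in x, so x_t = scale * x0 at every step
    for k in range(num_steps):
        t = k * h
        # u(t, x) = E[x1 - x0 | x_t = x] = Cov[x1 - x0, x_t] / Var[x_t] * x
        slope = (t * sigma**2 - (1 - t)) / _var(t, sigma)
        scale *= 1 + h * slope  # one Euler step
    return scale**2


def exact_transition_variance(sigma: float, num_steps: int) -> float:
    """Terminal variance when x_{t+h} is sampled from the exact Gaussian p(x_{t+h} | x_t)."""
    h = 1.0 / num_steps
    var = 1.0  # Var[x_0] under the assumed N(0, 1) source
    for k in range(num_steps):
        t, s = k * h, (k + 1) * h
        cov = (1 - t) * (1 - s) + t * s * sigma**2  # Cov[x_t, x_s]
        gain = cov / _var(t, sigma)  # coefficient of the conditional mean
        noise = _var(s, sigma) - cov**2 / _var(t, sigma)  # conditional variance
        var = gain**2 * var + noise  # propagate Var[x_s]
    return var


def kl_zero_mean_gauss(var_p: float, var_q: float) -> float:
    """KL(N(0, var_p) || N(0, var_q)); infinite if var_p has collapsed to 0."""
    if var_p <= 0.0:
        return math.inf
    r = var_p / var_q
    return 0.5 * (r - 1.0 - math.log(r))


if __name__ == "__main__":
    sigma = 1.0
    for steps in (1, 2, 4, 16, 100):
        v_fm = fm_euler_variance(sigma, steps)
        v_tm = exact_transition_variance(sigma, steps)
        print(f"{steps:>3} steps | FM var {v_fm:.4f} "
              f"(KL {kl_zero_mean_gauss(v_fm, sigma**2):.4f}) "
              f"| exact-transition var {v_tm:.4f}")
```

With σ = 1, the Euler FM variances for 1, 2 and 4 steps are 0, 0.25 and 0.5184, approaching 1 only as the step count grows, while the exact-transition variance stays at 1 throughout.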
These theoretical gains translate to practice. Across image and video generation benchmarks, TM consistently achieves better quality under the same or lower compute budgets, reaching competitive or superior performance with fewer sampling steps.
Class-Conditioned Image Generation
Frame-Conditioned Video Generation
For detailed theoretical analysis and additional experiments, see the full paper.
---
📋 Table of Contents
- [Demystifying Transition Matching](#demystifying-transition-matching)
- [📋 Table of Contents](#-table-of-contents)
- [📁 Project Structure](#-project-structure)
- [🔧 Setup](#-setup)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [🖼️ Image Generation (TM-Image)](#️-image-generation-tm-image)
- [Available Models](#available-models)
- [Dataset \& Pretrained Models](#dataset--pretrained-models)
- [💾 (Optional) Caching VAE Latents](#-optional-caching-vae-latents)
- [🏋️ Training](#️-training)
- [📊 Inference \& Evaluation](#-inference--evaluation)
- [🎬 Video Generation (TM-Video)](#-video-generation-tm-video)
- [Supported Datasets](#supported-datasets)
- [📡 (Optional) Connect to Weights \& Biases](#-optional-connect-to-weights--biases)
- [🏋️ Training](#️-training-1)
- [📊 Inference \& Evaluation](#-inference--evaluation-1)
- [📏 Evaluation Metrics](#-evaluation-metrics)
- [🙏 Acknowledgements](#-acknowledgements)
- [📄 Citations](#-citations)
- [🔒 Security](#-security)
- [⚖️ License](#️-license)
---
📁 Project Structure
TransitionFlowMatching/
├── TM-Image/                 # Image generation module
│   ├── main_mar.py           # Training & inference entry point
│   ├── main_cache.py         # VAE latent caching
│   ├── engine_mar.py         # Training & evaluation engine
│   ├── models/
│   │   ├── mar.py            # MAR backbone
│   │   ├── dtm.py            # Discrete Transition Matching
│   │   ├── ar_backbone.py    # Causal autoregressive backbone
│   │   ├── diffloss.py       # Diffusion head
│   │   └── vae.py            # VAE architecture
│   ├── diffusion/            # Gaussian diffusion utilities
│   ├── util/                 # Data loading, LR scheduling, misc
│   └── fid_stats/            # Pre-computed FID statistics
│
├── TM-Video/                 # Video generation module
│   ├── main.py               # Hydra-based entry point
│   ├── algorithms/
│   │   ├── dfot/             # Diffusion Forcing Transformer
│   │   ├── vae/              # Video/image VAE
│   │   └── common/           # Shared components & metrics
│   ├── configurations/       # Hydra YAML configs
│   │   ├── algorithm/        # Model configs (token_video, etc.)
│   │   ├── dataset/          # Dataset configs (Kinetics-600, etc.)
│   │   ├── experiment/       # Experiment configs
│   │   └── shortcut/         # Pre-built configs (@DiT/token_XL)
│   ├── datasets/             # Dataset implementations
│   ├── experiments/          # Training/evaluation scripts
│   └── utils/                # Checkpointing, W&B, distributed
│
├── requirements.txt
├── LICENSE
├── CODE_OF_CONDUCT.md
└── CONTRIBUTING.md
---
🔧 Setup
Prerequisites
- Python 3.10
- CUDA-compatible GPU(s)
- Conda package manager
Installation
conda create python=3.10 -n tm
conda activate tm
pip install -r requirements.txt
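After installation, a quick check that the environment matches the prerequisites can save a failed launch. This is a minimal sketch that is not part of the repository; it assumes requirements.txt installs PyTorch, and it only reports what it finds rather than failing.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Report whether the interpreter and (optional) PyTorch match the stated prerequisites."""
    report = {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "python_is_3_10": sys.version_info[:2] == (3, 10),
        # Assumption: PyTorch comes from requirements.txt; absent means install did not run.
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "gpu_count": 0,
    }
    if report["torch_installed"]:
        import torch

        report["cuda_available"] = bool(torch.cuda.is_available())
        report["gpu_count"] = torch.cuda.device_count() if report["cuda_available"] else 0
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```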
---
🖼️ Image Generation (TM-Image)
Navigate to ./TM-Image for the image generation module, which is built on top of MAR (Masked Autoregressive models).
Available Models
| Model | Description |
|-------|-------------|
| mar_large | MAR backbone (large) |
| dtm_large | Discrete Transition Matching (large) |
| fm_large | Flow Matching baseline (large) |
Dataset & Pretrained Models
1. Download the ImageNet dataset and place it in your IMAGENET_PATH.
2. Download the pretrained VAE by running download_pretrained_vae() in download.py.
💾 (Optional) Caching VAE Latents
Since data augmentation only involves center cropping and random flipping, VAE latents can be pre-computed to speed up training:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_cache.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 \
--batch_size 128 \
--data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}
🏋️ Training
export NODE_RANK=0
export MASTER_ADDR=
export MASTER_PORT=29500
torchrun --nproc_per_node=${N_GPU} --nnodes=${NNODES} \
--node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_mar.py \
--img_size 256 \
--vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model ${MODEL} --diffloss_d ${DIFF_DEPTH} --diffloss_w ${DIFF_CH} \
--epochs ${EPOCHS} --warmup_epochs 100 --batch_size 128 --diffusion_batch_mul 1 \
--output_dir ${OUTPUT_DIR} \
--use_cached --cached_path ${CACHED_PATH} --save_last_freq 100 \
--blr 1.0e-4 --lr 0.0005 --T ${T}
Key arguments:
| Argument | Description | Default |
|----------|-------------|---------|
| N_GPU | Number of GPUs per node | — |
| NNODES | Number of nodes | — |
| MODEL | Model type (mar_large, fm_large, dtm_large) | — |
| DIFF_DEPTH | Diffusion head depth | 6 |
| DIFF_CH | Diffusion head channel size | 1024 |
| EPOCHS | Training epochs | 500 |
| T | Number of discretization steps for TM | 128 |
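The training command can be assembled programmatically, which helps when sweeping MODEL or T. The sketch below uses a hypothetical helper, build_train_cmd, that is not part of the repository; it only reuses the flags and defaults from the training command and argument table, and it leaves MASTER_ADDR for the caller to supply, as the export line does. The learning-rate helper assumes TM-Image keeps the upstream MAR convention, in which lr = blr × effective batch size / 256 is derived only when --lr is omitted. Under that assumption, the explicit --lr 0.0005 in the command takes precedence.

```python
import shlex

# Defaults from the "Key arguments" table; MODEL must be one of the listed types.
DEFAULTS = {"DIFF_DEPTH": 6, "DIFF_CH": 1024, "EPOCHS": 500, "T": 128}
MODELS = ("mar_large", "fm_large", "dtm_large")
PER_GPU_BATCH = 128  # --batch_size in the training command


def build_train_cmd(model, output_dir, cached_path, n_gpu, master_addr,
                    nnodes=1, node_rank=0, master_port=29500, **overrides):
    """Return the torchrun argv for TM-Image training (hypothetical helper)."""
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"unknown overrides: {sorted(unknown)}")
    cfg = {**DEFAULTS, **overrides}
    return [
        "torchrun", f"--nproc_per_node={n_gpu}", f"--nnodes={nnodes}",
        f"--node_rank={node_rank}", f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "main_mar.py",
        "--img_size", "256",
        "--vae_path", "pretrained_models/vae/kl16.ckpt", "--vae_embed_dim", "16",
        "--vae_stride", "16", "--patch_size", "1",
        "--model", model,
        "--diffloss_d", str(cfg["DIFF_DEPTH"]), "--diffloss_w", str(cfg["DIFF_CH"]),
        "--epochs", str(cfg["EPOCHS"]), "--warmup_epochs", "100",
        "--batch_size", str(PER_GPU_BATCH), "--diffusion_batch_mul", "1",
        "--output_dir", output_dir,
        "--use_cached", "--cached_path", cached_path, "--save_last_freq", "100",
        "--blr", "1.0e-4", "--lr", "0.0005", "--T", str(cfg["T"]),
    ]


def effective_batch_size(n_gpu, nnodes=1, per_gpu=PER_GPU_BATCH):
    """Global batch size summed over all processes."""
    return per_gpu * n_gpu * nnodes


def mar_style_lr(blr, eff_batch):
    """Assumed upstream-MAR scaling rule, used only when --lr is not passed."""
    return blr * eff_batch / 256


if __name__ == "__main__":
    # Single-node example; 127.0.0.1 is an assumed address for a local run.
    cmd = build_train_cmd("dtm_large", "output/dtm", "cache/imagenet",
                          n_gpu=8, master_addr="127.0.0.1")
    print(shlex.join(cmd))
    print(effective_batch_size(8), mar_style_lr(1.0e-4, effective_batch_size(8)))
```

Keeping the defaults in one dictionary and rejecting unknown overrides makes a typo such as t=64 fail loudly instead of silently training with T = 128.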
📊…