amazon-science/Generative-vs-Discriminative-Classifiers
Language: Python
License: NOASSERTION
Stars: 7
Forks: 0
Open issues: 0
Created: 2025-09-20T04:06:57Z
Pushed: 2025-11-26T04:06:15Z
Default branch: main
Fork: no
Archived: no
README:
Generative vs Discriminative Text Classification: A Comprehensive Comparison
> 🏆 Outstanding Paper Award at EMNLP 2025
> Award Announcement | arXiv Paper
This repository contains the official implementation for the paper "Generative or Discriminative? Revisiting Text Classification in the Era of Transformers" by Siva Rajesh Kasa et al.
📖 Abstract
The comparison between discriminative and generative classifiers has intrigued researchers since Efron's seminal analysis of logistic regression versus discriminant analysis. While early theoretical work established that generative classifiers exhibit lower sample complexity but higher asymptotic error in simple linear settings, these trade-offs remain unexplored in the transformer era. We present the first comprehensive evaluation of modern generative and discriminative architectures - Auto-regressive modeling, Masked Language Modeling, Discrete Diffusion, and Encoders for text classification.
🏗️ Repository Structure
├── README.md                       # This file
├── environment.yml                 # Shared conda environment for AR, AR-Pseudo, and Encoder/MLM
├── ar/                             # Autoregressive classifier models
│   ├── train_gpt.py                # Training script for GPT-based classifiers
│   └── infer_gpt.py                # Inference script for GPT-based classifiers
├── ar_pseudo/                      # Pseudo-autoregressive variant classifiers
│   ├── train_gpt.py                # Training script for pseudo-AR classifiers
│   └── infer_gpt.py                # Inference script for pseudo-AR classifiers
├── diff/                           # Discrete diffusion classifier models
│   ├── README.md                   # Detailed documentation for diffusion models
│   ├── environment.yml             # Conda environment for diffusion models
│   ├── run_train.py                # Training script
│   ├── run_sample.py               # Sampling script
│   ├── parallel_inference.py       # Parallel inference for classification
│   ├── model/                      # Model architectures
│   ├── configs/                    # Configuration files
│   └── ...                         # Additional diffusion-related files
└── encoder_mlm/                    # Encoder and MLM classifier models
    ├── mlm_classif_seed_fixed.py   # Training script with fixed seeds
    └── inference.py                # Inference script
🚀 Quick Start
Automated Setup
Use our setup script for easy environment configuration:
# Check prerequisites and list approaches
python setup.py --check
python setup.py --list
# Setup your chosen approach
python setup.py --approach ar          # Autoregressive models
python setup.py --approach diffusion   # Diffusion models
python setup.py --approach encoder     # Encoder models
Quick Demo
Run a quick demo to verify your setup:
# Automated comprehensive demo (recommended)
./examples/run_comprehensive_experiments.sh demo
Manual Installation
If you prefer manual setup, note that diffusion models require a separate conda environment, while AR, AR-Pseudo, and Encoder/MLM models share a single environment:
1. Shared Environment: AR, AR-Pseudo, and Encoder/MLM Models
# Create shared environment for AR, AR-Pseudo, and Encoder/MLM approaches
conda env create -f environment.yml
conda activate gendisc-transformers
2. Separate Environment: Discrete Diffusion Models
# Create separate environment for diffusion models
cd diff/
conda env create -f environment.yml
conda activate sedd
🔬 Experiments
Autoregressive Classification
Train GPT-based classifiers using generative modeling:
cd ar/
python train_gpt.py \
    --data_key "SetFit/sst2" \
    --ckpt_dir "./checkpoints/ar_sst2" \
    --model_size "small" \
    --max_epochs 50 \
    --bsz 8
Key Parameters:
- --data_key: Dataset identifier (e.g., "SetFit/sst2", "emotion", "ag_news")
- --model_size: Model size ("small", "medium", "full")
- --n_devices: Number of GPUs to use
- --max_len: Maximum sequence length
- --seed: Random seed for reproducibility
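As a rough illustration of how the flags above fit together, the sketch below builds an `argparse` parser with the same flag names. It is not taken from `train_gpt.py`: the defaults for `--max_epochs` and `--bsz` are copied from the example command, and the defaults for `--n_devices`, `--max_len`, and `--seed` are assumptions, since the README does not state them.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented flags; the real
    # train_gpt.py may define types, defaults, and choices differently.
    parser = argparse.ArgumentParser(description="Sketch of train_gpt.py flags")
    parser.add_argument("--data_key", type=str, required=True,
                        help='Dataset identifier, e.g. "SetFit/sst2"')
    parser.add_argument("--ckpt_dir", type=str, required=True,
                        help="Directory where checkpoints are written")
    parser.add_argument("--model_size", choices=["small", "medium", "full"],
                        default="small", help="Model size")
    parser.add_argument("--max_epochs", type=int, default=50)  # from the example command
    parser.add_argument("--bsz", type=int, default=8)          # from the example command
    parser.add_argument("--n_devices", type=int, default=1)    # assumed default
    parser.add_argument("--max_len", type=int, default=None)   # assumed: no default known
    parser.add_argument("--seed", type=int, default=0)         # assumed default
    return parser


# Parse the same arguments as the example command, minus the explicit defaults.
args = build_parser().parse_args([
    "--data_key", "SetFit/sst2",
    "--ckpt_dir", "./checkpoints/ar_sst2",
    "--model_size", "small",
])
print(args.data_key, args.model_size, args.max_epochs, args.bsz)
```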
Discrete Diffusion Classification
Train discrete diffusion models for text classification:
cd diff/
# Single experiment with environment variables
DATASET_NAME="SetFit/sst2" TRAIN_SIZE="1024" N_ITERS="50000" python train.py model=small
# Or run comprehensive experiments across multiple datasets and sizes
./run_exps.sh
For inference:
python parallel_inference.py \
    --model_path "path/to/trained/model" \
    --dataset "ag_news" \
    --batch_size 32
Encoder/MLM Classification
Run comprehensive experiments with BERT-based models:
cd encoder_mlm/
python mlm_classif_seed_fixed.py
This script runs experiments across:
- Multiple datasets (emotion, sst2, ag_news, etc.)
- Different model sizes (1 layer, 6 layers, 12 layers)
- Various training sample sizes (128, 256, 512, 1024, 2048, 4096, full)
- Multiple random seeds for statistical significance
- Both MLM pretraining and direct classification approaches
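The sweep described above can be pictured as a Cartesian product of settings. The sketch below enumerates such a grid with `itertools.product`. It is only an illustration, not the logic of `mlm_classif_seed_fixed.py`: the seed values and the two approach labels are hypothetical, and only the three datasets named explicitly in the list are included.

```python
from itertools import product

# Values taken from the list above; the "etc." in the dataset list is left out.
datasets = ["emotion", "sst2", "ag_news"]
num_layers = [1, 6, 12]
train_sizes = [128, 256, 512, 1024, 2048, 4096, "full"]
seeds = [0, 1, 2]                          # assumed: seed values are not stated
approaches = ["mlm_pretrain", "direct"]    # hypothetical labels for the two approaches


def experiment_grid():
    """Yield one configuration dict per combination of settings."""
    for ds, layers, size, seed, approach in product(
        datasets, num_layers, train_sizes, seeds, approaches
    ):
        yield {
            "dataset": ds,
            "num_layers": layers,
            "train_size": size,
            "seed": seed,
            "approach": approach,
        }


configs = list(experiment_grid())
print(len(configs))  # 3 * 3 * 7 * 3 * 2 = 378
```

Even this reduced grid produces several hundred runs, which is why a single script that loops over all settings is convenient.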
📊 Supported Datasets
The repository supports various text classification datasets:
- Sentiment Analysis: SST-2, SST-5, IMDb, Rotten Tomatoes
- Topic Classification: AG News
- Emotion Detection: Emotion dataset
- Hate Speech Detection: Hate Speech Offensive
- Multi-class Sentiment: Multi-class sentiment analysis
- Financial News: Twitter Financial News Sentiment
🔧 Model Architectures
1. Autoregressive (AR) Models
- GPT-2 based architecture
- Generative approach: P(label|text) via likelihood estimation
- Configurable model sizes (small, medium, full)
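One common way to turn likelihood estimates into P(label|text) is Bayes' rule: score each label by log P(label) + log P(text|label) and pick the highest. The sketch below shows that decision rule with a toy class-conditional unigram model standing in for GPT-2. Whether the repository's classifiers score labels in exactly this way is an assumption, and all function names here are hypothetical.

```python
import math
from collections import Counter


def fit(texts, labels):
    """Estimate log P(y) and per-class token counts for P(x | y)."""
    vocab = {tok for t in texts for tok in t.split()}
    label_counts = Counter(labels)
    token_counts = {y: Counter() for y in label_counts}
    for text, y in zip(texts, labels):
        token_counts[y].update(text.split())
    log_prior = {y: math.log(c / len(labels)) for y, c in label_counts.items()}
    return vocab, log_prior, token_counts


def log_joint(text, y, vocab, log_prior, token_counts):
    """log P(y) + sum_t log P(token_t | y), with add-one smoothing."""
    # One extra bucket in the denominator covers all unseen tokens.
    total = sum(token_counts[y].values()) + len(vocab) + 1
    score = log_prior[y]
    for tok in text.split():
        score += math.log((token_counts[y][tok] + 1) / total)
    return score


def classify(text, vocab, log_prior, token_counts):
    """Return the label with the highest joint log-likelihood."""
    return max(log_prior,
               key=lambda y: log_joint(text, y, vocab, log_prior, token_counts))


texts = ["great fun movie", "boring dull movie", "great acting", "dull plot"]
labels = ["pos", "neg", "pos", "neg"]
model = fit(texts, labels)
print(classify("great movie", *model))
print(classify("dull plot", *model))
```

Because the normalizer P(text) is the same for every label, comparing joint log-likelihoods gives the same argmax as comparing P(label|text).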
2. Pseudo-Autoregressive Models
- Modified autoregressive approach
- Hybrid generative-discriminative training
3. Discrete Diffusion Models
- Score-based discrete diffusion
- Novel application to text classification
- Supports both uniform and absorbing noise schedules
- Three model configurations available:
  - small: 1 layer, 1 attention head (1,1) - for quick experiments
  - medium: 6 layers, 6 attention heads (6,6) - balanced performance
  - large: 12 layers, 12 attention heads (12,12) - best performance
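To make the difference between uniform and absorbing noise concrete, the sketch below corrupts a token sequence at a fixed per-token probability `p`. Under absorbing noise a corrupted token becomes a mask token; under uniform noise it is resampled from the vocabulary. This is a simplified illustration, not the repository's implementation: the `MASK` token, the `corrupt` function, and the use of a single probability `p` (rather than a time-dependent noise schedule) are all assumptions.

```python
import random

MASK = "[MASK]"  # hypothetical absorbing-state token


def corrupt(tokens, vocab, p, kind, rng):
    """Independently corrupt each token with probability p.

    kind="absorbing": a corrupted token is replaced by MASK.
    kind="uniform":   a corrupted token is resampled uniformly from vocab
                      (it may land on the original token again).
    """
    if kind not in ("absorbing", "uniform"):
        raise ValueError(f"unknown noise kind: {kind!r}")
    out = []
    for tok in tokens:
        if rng.random() < p:
            out.append(MASK if kind == "absorbing" else rng.choice(vocab))
        else:
            out.append(tok)
    return out


rng = random.Random(0)
vocab = ["good", "bad", "film", "plot"]
tokens = ["good", "film", "bad", "plot"]
print(corrupt(tokens, vocab, 0.5, "absorbing", rng))
print(corrupt(tokens, vocab, 0.5, "uniform", rng))
```

At `p = 0` both variants return the input unchanged; at `p = 1` absorbing noise yields an all-mask sequence, while uniform noise yields a sequence of random vocabulary tokens.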
4. Encoder Models
- BERT-based discriminative classifiers
- Masked Language Model (MLM) pretraining option
- Standard discriminative approach: direct classification head
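A direct classification head is typically a linear layer followed by a softmax over the pooled encoder output. The sketch below writes that computation out in plain Python so the shapes are explicit. The function names and the toy weights are hypothetical, and the repository presumably relies on a deep learning framework rather than code like this.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def classification_head(pooled, weights, bias):
    """Map a pooled encoder vector h to class probabilities softmax(W h + b).

    pooled:  list of d floats (e.g. a pooled sentence representation)
    weights: list of num_classes rows, each with d floats
    bias:    list of num_classes floats
    """
    logits = [sum(w * h for w, h in zip(row, pooled)) + b
              for row, b in zip(weights, bias)]
    return softmax(logits)


# Toy example: 2-dimensional pooled vector, 2 classes.
probs = classification_head([1.0, -1.0], [[2.0, 0.0], [0.0, 2.0]], [0.0, 0.0])
predicted = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted)
```

Unlike the generative models, this head models P(label|text) directly and never assigns a likelihood to the text itself.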
📈 Key Findings
Our comprehensive evaluation reveals:
1. Sample Efficiency: Generative models show superior performance…