
amazon-science/Generative-vs-Discriminative-Classifiers

Language: Python

License: NOASSERTION

Stars: 7

Forks: 0

Open issues: 0

Created: 2025-09-20T04:06:57Z

Pushed: 2025-11-26T04:06:15Z

Default branch: main

Fork: no

Archived: no

README:

Generative vs Discriminative Text Classification: A Comprehensive Comparison

> 🏆 Outstanding Paper Award at EMNLP 2025
> Award Announcement | arXiv Paper

This repository contains the official implementation for the paper "Generative or Discriminative? Revisiting Text Classification in the Era of Transformers" by Siva Rajesh Kasa et al.

📖 Abstract

The comparison between discriminative and generative classifiers has intrigued researchers since Efron's seminal analysis of logistic regression versus discriminant analysis. While early theoretical work established that generative classifiers exhibit lower sample complexity but higher asymptotic error in simple linear settings, these trade-offs remain unexplored in the transformer era. We present the first comprehensive evaluation of modern generative and discriminative architectures - Auto-regressive modeling, Masked Language Modeling, Discrete Diffusion, and Encoders for text classification.

🏗️ Repository Structure

├── README.md # This file
├── environment.yml # Shared conda environment for AR, AR-Pseudo, and Encoder/MLM
├── ar/ # Autoregressive classifier models
│ ├── train_gpt.py # Training script for GPT-based classifiers
│ └── infer_gpt.py # Inference script for GPT-based classifiers
├── ar_pseudo/ # Pseudo-autoregressive variant classifiers
│ ├── train_gpt.py # Training script for pseudo-AR classifiers
│ └── infer_gpt.py # Inference script for pseudo-AR classifiers
├── diff/ # Discrete diffusion classifier models
│ ├── README.md # Detailed documentation for diffusion models
│ ├── environment.yml # Conda environment for diffusion models
│ ├── run_train.py # Training script
│ ├── run_sample.py # Sampling script
│ ├── parallel_inference.py # Parallel inference for classification
│ ├── model/ # Model architectures
│ ├── configs/ # Configuration files
│ └── ... # Additional diffusion-related files
└── encoder_mlm/ # Encoder and MLM classifier models
    ├── mlm_classif_seed_fixed.py # Training script with fixed seeds
    └── inference.py # Inference script

🚀 Quick Start

Automated Setup

Use our setup script for easy environment configuration:

# Check prerequisites and list approaches
python setup.py --check
python setup.py --list

# Setup your chosen approach
python setup.py --approach ar # Autoregressive models
python setup.py --approach diffusion # Diffusion models
python setup.py --approach encoder # Encoder models

Quick Demo

Run a quick demo to verify your setup:

# Automated comprehensive demo (recommended)
./examples/run_comprehensive_experiments.sh demo

Manual Installation

If you prefer manual setup, note that diffusion models require a separate conda environment, while AR, AR-Pseudo, and Encoder/MLM models share a single environment:

1. Shared Environment: AR, AR-Pseudo, and Encoder/MLM Models

# Create shared environment for AR, AR-Pseudo, and Encoder/MLM approaches
conda env create -f environment.yml
conda activate gendisc-transformers

2. Separate Environment: Discrete Diffusion Models

# Create separate environment for diffusion models
cd diff/
conda env create -f environment.yml
conda activate sedd

🔬 Experiments

Autoregressive Classification

Train GPT-based classifiers using generative modeling:

cd ar/
python train_gpt.py \
    --data_key "SetFit/sst2" \
    --ckpt_dir "./checkpoints/ar_sst2" \
    --model_size "small" \
    --max_epochs 50 \
    --bsz 8

Key Parameters:

  • --data_key: Dataset identifier (e.g., "SetFit/sst2", "emotion", "ag_news")
  • --model_size: Model size ("small", "medium", "full")
  • --n_devices: Number of GPUs to use
  • --max_len: Maximum sequence length
  • --seed: Random seed for reproducibility
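
For repeated runs, the flags above can be assembled programmatically. The following is a minimal sketch, assuming only the flags documented in this README; the helper name `build_train_command`, the seed values, and the per-seed checkpoint directory naming are hypothetical and not part of the repository.

```python
import shlex
from typing import List, Optional


def build_train_command(
    data_key: str,
    ckpt_dir: str,
    model_size: str = "small",
    max_epochs: int = 50,
    bsz: int = 8,
    seed: Optional[int] = None,
) -> List[str]:
    """Assemble an argv list for ar/train_gpt.py from the documented flags."""
    cmd = [
        "python", "train_gpt.py",
        "--data_key", data_key,
        "--ckpt_dir", ckpt_dir,
        "--model_size", model_size,
        "--max_epochs", str(max_epochs),
        "--bsz", str(bsz),
    ]
    # --seed is only appended when the caller asks for a specific seed.
    if seed is not None:
        cmd += ["--seed", str(seed)]
    return cmd


# Hypothetical seeds: the README does not say which seeds the paper used.
SEEDS = [0, 1, 2]
commands = [
    build_train_command("SetFit/sst2", f"./checkpoints/ar_sst2_seed{s}", seed=s)
    for s in SEEDS
]

# Print the commands for inspection instead of launching training.
for cmd in commands:
    print(shlex.join(cmd))
```

Each list can then be passed to `subprocess.run(cmd, cwd="ar", check=True)` once the printed commands look right.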

Discrete Diffusion Classification

Train discrete diffusion models for text classification:

cd diff/
# Single experiment with environment variables
DATASET_NAME="SetFit/sst2" TRAIN_SIZE="1024" N_ITERS="50000" python train.py model=small

# Or run comprehensive experiments across multiple datasets and sizes
./run_exps.sh
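
Because the single-experiment command is configured through environment variables, a sweep can be prepared by building one environment per run. This is a sketch under that assumption only: `diffusion_env` is a hypothetical helper, the list of training sizes is illustrative, and how `train.py` interprets the variables is not shown in this README.

```python
import os
from typing import Dict, List, Tuple


def diffusion_env(dataset_name: str, train_size: int, n_iters: int) -> Dict[str, str]:
    """Copy the current environment and set the three documented variables."""
    env = dict(os.environ)
    env.update(
        {
            "DATASET_NAME": dataset_name,
            "TRAIN_SIZE": str(train_size),  # environment values must be strings
            "N_ITERS": str(n_iters),
        }
    )
    return env


# Command taken from the README; the sizes below are an illustrative subset.
TRAIN_CMD: List[str] = ["python", "train.py", "model=small"]
TRAIN_SIZES = [128, 1024]

runs: List[Tuple[List[str], Dict[str, str]]] = [
    (TRAIN_CMD, diffusion_env("SetFit/sst2", n, 50000)) for n in TRAIN_SIZES
]

for cmd, env in runs:
    print(env["DATASET_NAME"], env["TRAIN_SIZE"], env["N_ITERS"], " ".join(cmd))
```

Passing `env=env` and `cwd="diff"` to `subprocess.run` would launch each run without modifying the parent shell's environment.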

For inference:

python parallel_inference.py \
    --model_path "path/to/trained/model" \
    --dataset "ag_news" \
    --batch_size 32

Encoder/MLM Classification

Run comprehensive experiments with BERT-based models:

cd encoder_mlm/
python mlm_classif_seed_fixed.py

This script runs experiments across:

  • Multiple datasets (emotion, sst2, ag_news, etc.)
  • Different model sizes (1 layer, 6 layers, 12 layers)
  • Various training sample sizes (128, 256, 512, 1024, 2048, 4096, full)
  • Multiple random seeds for statistical significance
  • Both MLM pretraining and direct classification approaches
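
To give a sense of the size of such a sweep, the axes listed above can be enumerated as a Cartesian product. This is only an illustrative sketch: the dataset list is the subset named in the bullet, the seed values and the mode labels `mlm_pretrain` and `direct` are assumptions, and the script's real internal loop may be organized differently.

```python
from itertools import product

DATASETS = ["emotion", "sst2", "ag_news"]   # subset named in the README
NUM_LAYERS = [1, 6, 12]
TRAIN_SIZES = [128, 256, 512, 1024, 2048, 4096, "full"]
SEEDS = [0, 1, 2]                           # hypothetical seed values
MODES = ["mlm_pretrain", "direct"]          # hypothetical labels for the two approaches

# One dict per experiment configuration.
grid = [
    {"dataset": d, "num_layers": n_layers, "train_size": size, "seed": s, "mode": m}
    for d, n_layers, size, s, m in product(DATASETS, NUM_LAYERS, TRAIN_SIZES, SEEDS, MODES)
]

print(len(grid))  # 3 * 3 * 7 * 3 * 2 = 378 configurations
print(grid[0])
```

Even with three datasets and three seeds the grid reaches several hundred runs, which is worth knowing before launching the full script.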

📊 Supported Datasets

The repository supports various text classification datasets:

  • Sentiment Analysis: SST-2, SST-5, IMDb, Rotten Tomatoes
  • Topic Classification: AG News
  • Emotion Detection: Emotion dataset
  • Hate Speech Detection: Hate Speech Offensive
  • Multi-class Sentiment: Multi-class sentiment analysis
  • Financial News: Twitter Financial News Sentiment

🔧 Model Architectures

1. Autoregressive (AR) Models

  • GPT-2 based architecture
  • Generative approach: P(label|text) via likelihood estimation
  • Configurable model sizes (small, medium, full)
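
One common way to turn a generative language model into a classifier is to score the text jointly with each candidate label and normalize the scores over labels. The sketch below illustrates that idea with made-up log-likelihoods; it is not taken from `infer_gpt.py`, and the function name `label_posterior` and the numbers are assumptions for illustration.

```python
import math
from typing import Dict


def label_posterior(joint_loglik: Dict[str, float]) -> Dict[str, float]:
    """Normalize log p(text, label) over labels to obtain p(label | text).

    Uses the log-sum-exp trick so that very negative log-likelihoods
    do not underflow to zero.
    """
    m = max(joint_loglik.values())
    log_z = m + math.log(sum(math.exp(v - m) for v in joint_loglik.values()))
    return {label: math.exp(v - log_z) for label, v in joint_loglik.items()}


# Made-up joint log-likelihoods, as a model might assign to one review.
scores = {"positive": -41.2, "negative": -43.5}

posterior = label_posterior(scores)
prediction = max(posterior, key=posterior.get)

print(posterior)
print(prediction)
```

Since the normalizer is the same for every label, the predicted class is simply the label with the highest joint log-likelihood; the normalization is only needed when calibrated probabilities are wanted.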

2. Pseudo-Autoregressive Models

  • Modified autoregressive approach
  • Hybrid generative-discriminative training

3. Discrete Diffusion Models

  • Score-based discrete diffusion
  • Novel application to text classification
  • Supports both uniform and absorbing noise schedules
  • Three model configurations available:
      • small: 1 layer, 1 attention head (1,1) - for quick experiments
      • medium: 6 layers, 6 attention heads (6,6) - balanced performance
      • large: 12 layers, 12 attention heads (12,12) - best performance
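
The three configurations can be summarized as a small lookup table. This is a hypothetical sketch for reference only: the names `DiffusionSize` and `resolve_size` are not from the repository, and other hyperparameters (hidden size, dropout, and so on) are omitted because the README does not state them.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class DiffusionSize:
    """Layer and attention-head counts for one named configuration."""
    n_layers: int
    n_heads: int


# Values copied from the configuration list above.
DIFFUSION_SIZES: Dict[str, DiffusionSize] = {
    "small": DiffusionSize(n_layers=1, n_heads=1),
    "medium": DiffusionSize(n_layers=6, n_heads=6),
    "large": DiffusionSize(n_layers=12, n_heads=12),
}


def resolve_size(name: str) -> DiffusionSize:
    """Look up a configuration by name, with a readable error for typos."""
    try:
        return DIFFUSION_SIZES[name]
    except KeyError:
        valid = ", ".join(sorted(DIFFUSION_SIZES))
        raise ValueError(f"unknown model size {name!r}; expected one of: {valid}") from None


print(resolve_size("medium"))
```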

4. Encoder Models

  • BERT-based discriminative classifiers
  • Masked Language Model (MLM) pretraining option
  • Standard discriminative approach: direct classification head
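
For reference, the difference between the discriminative head described here and the generative approaches in the preceding sections can be written as two training objectives. This is the standard textbook formulation, not notation taken from the paper: $x$ denotes a text, $y$ its label, $\theta$ the model parameters, and $N$ the number of training examples.

```latex
\begin{aligned}
\text{Discriminative:}\quad
  & \hat\theta = \arg\max_{\theta} \sum_{i=1}^{N} \log p_\theta(y_i \mid x_i),
  & \hat y(x) &= \arg\max_{y}\, p_{\hat\theta}(y \mid x) \\
\text{Generative:}\quad
  & \hat\theta = \arg\max_{\theta} \sum_{i=1}^{N} \log p_\theta(x_i, y_i),
  & \hat y(x) &= \arg\max_{y}\, p_{\hat\theta}(x, y)
\end{aligned}
```

Both decision rules pick the most probable label given the text, since $p(x, y)$ and $p(y \mid x)$ differ only by the factor $p(x)$, which does not depend on $y$. The models differ in what they are trained to fit.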

📈 Key Findings

Our comprehensive evaluation reveals:

1. Sample Efficiency: Generative models show superior performance…
