OpenBMB/DecT
Description: Source code for ACL 2023 paper Decoder Tuning: Efficient Language Understanding as Decoding
Language: Python
License: MIT
Stars: 53
Forks: 8
Open issues: 0
Created: 2023-05-05T01:50:11Z
Pushed: 2023-06-25T08:57:48Z
Default branch: main
Fork: no
Archived: no
README:
DecT
Source code for ACL 2023 paper Decoder Tuning
Installation
Our code is based on PyTorch, HuggingFace Transformers, and OpenPrompt. Install the dependencies with
pip install -r requirements.txt
Download Datasets
Download the 10 datasets with the following scripts
cd datasets
bash download_datasets.sh
cd ..
Run DecT
Then you can run DecT with src/run_dect.py, for example
python src/run_dect.py \
    --model roberta \
    --size large \
    --type mlm \
    --model_name_or_path roberta-large \
    --shot 1 \
    --dataset sst2 \
    --proto_dim 128 \
    --model_logits_weight 1
In run_dect.py we provide instructions for each argument. To reproduce the results in the paper, please run every combination of the following values, where each bracketed list gives the values to sweep
python src/run_dect.py \
    --shot [1, 4, 16] \
    --dataset [sst2, imdb, yelp, agnews, dbpedia, yahoo, rte, snli, mnli-m, mnli-mm, fewnerd] \
    --seed [0, 1, 2, 3, 4]
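The bracketed lists above are shorthand for a sweep, so each run still needs one concrete value per flag. A minimal sketch of how the sweep could be expanded into individual commands is shown below. It assumes that the model flags from the RoBERTa-large example (including --proto_dim 128 and --model_logits_weight 1) apply to every run, which this README does not state explicitly. The names BASE_ARGS and build_commands are ours and are not part of the repository. The script only prints the commands; it does not launch them.

```python
import itertools
import shlex

# Sweep values copied from the combinations listed above.
SHOTS = [1, 4, 16]
DATASETS = ["sst2", "imdb", "yelp", "agnews", "dbpedia", "yahoo",
            "rte", "snli", "mnli-m", "mnli-mm", "fewnerd"]
SEEDS = [0, 1, 2, 3, 4]

# ASSUMPTION: the model flags from the RoBERTa-large example are reused
# unchanged for every run. Adjust them if your setup differs.
BASE_ARGS = [
    "python", "src/run_dect.py",
    "--model", "roberta",
    "--size", "large",
    "--type", "mlm",
    "--model_name_or_path", "roberta-large",
    "--proto_dim", "128",
    "--model_logits_weight", "1",
]


def build_commands():
    """Return one argument list per (shot, dataset, seed) combination."""
    commands = []
    for shot, dataset, seed in itertools.product(SHOTS, DATASETS, SEEDS):
        commands.append(
            BASE_ARGS
            + ["--shot", str(shot), "--dataset", dataset, "--seed", str(seed)]
        )
    return commands


if __name__ == "__main__":
    # Print shell-ready command lines, one per run.
    for cmd in build_commands():
        print(shlex.join(cmd))
```

You can redirect the printed lines into a shell script or pass each argument list to subprocess.run if you prefer to launch the runs from Python.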
Configure Models
You can configure different models by setting the model, type, size, and model_name_or_path parameters.
model: Model name. We now support PLMs in OpenPrompt, LLaMA, Alpaca and Vicuna.
type: mlm, lm or chat. This will determine the prompt template. For lm type models, we put the [mask] token at the end of the template. For chat models, we implement the chat template for Vicuna v1.1. You may change the template if you use other models.
size: Model size. Currently, it is used to set the hidden state dimension for LLaMA models.
model_name_or_path: Path to model weights.
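To illustrate how the type parameter could change where the [mask] token ends up, here is a rough sketch. It is not the template code from src/run_dect.py: the function build_prompt, the "It was" wording, and the exact mask placement for chat models are all hypothetical. The chat branch follows the commonly documented Vicuna v1.1 format (a system sentence followed by USER: and ASSISTANT: turns), which may differ in detail from what the repository implements.

```python
# System prompt of the commonly documented Vicuna v1.1 chat format.
VICUNA_V11_SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the "
    "user's questions."
)


def build_prompt(model_type: str, text: str, mask: str = "[mask]") -> str:
    """Return an illustrative prompt string for the given model type.

    HYPOTHETICAL helper: the wording is made up for this example and is
    not taken from the DecT source code.
    """
    if model_type == "mlm":
        # Masked LMs see context on both sides, so the mask may sit mid-sentence.
        return f"{text} It was {mask} ."
    if model_type == "lm":
        # Left-to-right LMs only see the prefix, so the mask goes last.
        return f"{text} It was {mask}"
    if model_type == "chat":
        # ASSUMPTION: the mask is the first position of the assistant reply.
        return f"{VICUNA_V11_SYSTEM} USER: {text} ASSISTANT: {mask}"
    raise ValueError(f"unknown model type: {model_type!r}")


if __name__ == "__main__":
    for model_type in ("mlm", "lm", "chat"):
        print(model_type, "->", build_prompt(model_type, "A great movie."))
```

If you use a chat model other than Vicuna v1.1, the chat branch is the part you would most likely need to replace with that model's own conversation format.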
You can also modify the load_model function in src/run_dect.py to support more models!