Model · Cerebras · published Feb 26, 2025

cerebras/Llama-3-CBHybridL-8B

Task: text-generation · License: other · Parameters: 8B

Llama-3-CBHybridL-8B: Model Information

We are excited to release Cerebras hybrid dense/sparse attention versions of Llama-3.1-8B-Instruct, optimized for long-context performance. The series includes two models: Llama3.1-CBHybridL-8B (25 of its 32 attention layers are sparse) and Llama3.1-CBHybridM-8B (28 of its 32 attention layers are sparse).

This model, Cerebras Llama3.1-CBHybridL-8B, was built on top of Llama-3.1-8B-Instruct using the sparse attention training features available in Cerebras Model Zoo Release 2.4. In these hybrid versions of Llama-3.1-8B-Instruct, most self-attention layers are fine-tuned to perform sparse lambda-mask attention, which reduces KV cache memory usage by 1.6-1.7x while largely maintaining long-context performance.
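The lambda-mask idea can be sketched in plain Python. This sketch assumes the Λ-shaped mask described in the LM-Infinite work: attention stays causal, and each query sees only the first `n_global` tokens plus the `n_local` most recent ones. The names `lambda_mask`, `cached_kv_positions`, `n_global` and `n_local` are hypothetical, and the window sizes actually used in these models are not given in this card. Because a sparse layer never looks outside those two bands, its KV cache can be capped at `n_global + n_local` entries regardless of sequence length, which is the assumed source of the memory savings.

```python
def lambda_mask(seq_len: int, n_global: int, n_local: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Assumed LM-Infinite-style Λ-shaped mask: causal, restricted to the first
    `n_global` tokens plus the `n_local` most recent tokens (including i itself).
    """
    return [
        [j <= i and (j < n_global or i - j < n_local) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def cached_kv_positions(i: int, n_global: int, n_local: int) -> list[int]:
    """Key positions a sparse layer still needs in its KV cache to decode position i."""
    global_part = list(range(min(n_global, i + 1)))
    local_start = max(n_global, i - n_local + 1)  # avoid double-counting the global band
    return global_part + list(range(local_start, i + 1))


# Visualize the mask: "x" = attended, "." = masked out
for row in lambda_mask(seq_len=8, n_global=2, n_local=3):
    print("".join("x" if allowed else "." for allowed in row))
# x.......
# xx......
# xxx.....
# xxxx....
# xxxxx...
# xx.xxx..
# xx..xxx.
# xx...xxx

# The per-layer cache size stays at n_global + n_local, however long the sequence gets
print(len(cached_kv_positions(100_000, n_global=2, n_local=3)))  # 5
```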

You can find more information about Cerebras hybrid Llama models at the following locations:

Results

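Our hybrid models retain most of their long-context performance while requiring substantially less KV cache memory (1.275 GB for CBHybridM and 1.376 GB for CBHybridL, versus 2.147 GB for Llama-3.1-8B-Instruct, at a 16K sequence length):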

![HELMET result](./helmet_result.png)

| LongBench suite | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|-----------------------|-----------------------|----------------------|----------------------|
| KV cache memory (GB)\* | 2.147 | 1.275 | 1.376 |
| Single-doc QA | 54.197 | 54.507 | 56.187 |
| Multi-doc QA | 41.455 | 41.022 | 43.082 |
| Summarization | 26.1275 | 25.607 | 25.357 |
| Few-shot learning | 63.4075 | 64.42 | 65.183 |
| Synthetic | 97.29 | 96.75 | 98.0 |
| Code completion | 59.745 | 66.865 | 66.49 |
| Macro-mean (EN & ZH) | 57.037 | 58.195 | 59.05 |
| Macro-mean (EN) | 58.606 | 60.485 | 60.937 |

| HELMET suite (seq. len. 16K) | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|-----------------------|----------------------|----------------------|----------------------|
| KV cache memory (GB) | 2.147 | 1.275 | 1.376 |
| Recall | 99.6875 | 87.5625 | 95.1875 |
| Rerank | 52.6671 | 42.7879 | 45.5175 |
| RAG | 69.0417 | 68.625 | 69.4583 |
| LongdocQA | 32.061 | 34.419 | 35.2879 |
| ICL | 76 | 81.6 | 82.2 |
| Summarization | 26.278 | 22.4353 | 23.7324 |
| Macro-mean | 59.2892 | 56.2382 | 58.564 |

\* KV cache memory is reported at a representative sequence length of 16K tokens. Note, however, that LongBench samples vary in length; the 75th percentile of the sample-length distribution is about 14.5K tokens.
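As a sanity check on the 16K figures, the full-attention number can be reproduced from the Llama-3.1-8B configuration (32 layers, 8 key/value heads via grouped-query attention, head dimension 128), assuming a bf16 cache (2 bytes per element) and decimal gigabytes. The `hybrid_kv_cache_bytes` helper is a hypothetical sketch: it assumes dense layers cache every token and sparse layers cache at most `sparse_tokens` tokens, a value this card does not state, so it is not meant to reproduce the hybrid rows of the tables.

```python
def kv_cache_bytes(seq_len: int, num_layers: int = 32, num_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    """Full-attention KV cache size: one K and one V vector per layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * seq_len


def hybrid_kv_cache_bytes(seq_len: int, num_dense_layers: int, num_sparse_layers: int,
                          sparse_tokens: int) -> int:
    """Hypothetical estimate: dense layers cache all tokens, sparse layers at most `sparse_tokens`."""
    dense = kv_cache_bytes(seq_len, num_layers=num_dense_layers)
    sparse = kv_cache_bytes(min(seq_len, sparse_tokens), num_layers=num_sparse_layers)
    return dense + sparse


seq_len = 16 * 1024
print(f"full attention: {kv_cache_bytes(seq_len) / 1e9:.3f} GB")  # full attention: 2.147 GB

# Compression ratios implied by the tables (GB values copied from them)
for name, gb in {"CBHybridM": 1.275, "CBHybridL": 1.376}.items():
    print(f"{name}: {2.147 / gb:.2f}x smaller")  # prints 1.68x for CBHybridM, 1.56x for CBHybridL
```

The two ratios, roughly 1.56x and 1.68x, are where the 1.6-1.7x figure quoted above comes from.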

Example Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridL-8B"

# trust_remote_code=True lets transformers run the custom code shipped with the repo
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Format the conversation with the chat template and append the assistant prompt
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)
# Decode only the newly generated tokens, skipping the prompt
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

Adding memory tokens for enhanced long-context performance

We found that inserting auxiliary memory tokens into the input sequence at regular intervals improves long-context performance. These tokens can be inserted with the helper method `tokenizer.insert_memory_tokens()`, which is available when the tokenizer is loaded with `trust_remote_code=True`, as shown below:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridL-8B"

# trust_remote_code=True is needed for the custom insert_memory_tokens() helper
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Inserting 8 memory tokens per 256 tokens of original input:
input_ids = tokenizer.insert_memory_tokens(
    input_ids,
    episode_length=256,
    num_memory_tokens_per_episode=8,
)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)
# Slice at the (longer) memory-augmented prompt length to keep only generated tokens
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

In our ablations, inserting 8 memory tokens after every 256 tokens of original input gave the best accuracy. See our blog post for more details.
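For intuition, the insertion pattern can be sketched on a plain list of token ids. This is not the repo's `insert_memory_tokens()` implementation: the function name `insert_memory_tokens_sketch` and its `memory_token_id` argument are hypothetical, and it assumes memory tokens follow each *complete* 256-token episode. How the real helper treats a trailing partial episode, and which token ids it uses, is not stated in this card.

```python
def insert_memory_tokens_sketch(input_ids: list[int], episode_length: int,
                                num_memory_tokens_per_episode: int,
                                memory_token_id: int) -> list[int]:
    """Append `num_memory_tokens_per_episode` copies of a (hypothetical) memory token
    after every complete `episode_length`-token chunk of the original input."""
    out: list[int] = []
    for start in range(0, len(input_ids), episode_length):
        episode = input_ids[start:start + episode_length]
        out.extend(episode)
        if len(episode) == episode_length:  # assumption: no memory tokens after a partial episode
            out.extend([memory_token_id] * num_memory_tokens_per_episode)
    return out


ids = list(range(600))  # stand-in for a tokenized prompt
with_memory = insert_memory_tokens_sketch(ids, episode_length=256,
                                          num_memory_tokens_per_episode=8,
                                          memory_token_id=-1)  # -1 is a placeholder id
print(len(with_memory))  # 616: two full episodes get 8 memory tokens each
```

At 8 tokens per 256, the input grows by 8/256 = 3.125%, e.g. 512 extra tokens for a 16K-token (16,384-token) prompt.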

License

Built with Llama3. Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.

Llama3.1 Community License

Acceptable Use Policy

Acknowledgements

Our models are fine-tuned versions of Meta-Llama-3.1-8B-Instruct. The sparse attention mechanism used in the Llama-3-CBHybrid model series is from the LM-Infinite work of Han et al. See our blog post for the full list of references.

Citing this work

```bibtex
@misc{cerebras2025cb-hybrid-llama,
  author       = {Lazarevich, Ivan and Hassanpour, Mohammad and Venkatesh, Ganesh},
  title        = {Compressing KV cache memory by half with sparse attention},
  month        = {March},
  year         = {2025},
  howpublished = {\url{https://www.cerebras.ai/blog/compressing-kv-cache-memory-by-half-with-sparse-attention}}
}
```
