# llama-3.2-3b-attn-drop-3

## Model Description

This model is a surgically optimized version of meta-llama/Llama-3.2-3B, created as part of Chapter 8 of the book "Rearchitecting LLMs".

- Book: Rearchitecting LLMs
- Framework: OptiPFair
- Technique: Attention Optimization (Physical Attention Layer Removal)
- Chapter: Chapter 8 - Attention Optimization
- Notebook: CH08_NB03_Remove_Attention_
- Paper: "What Matters in Transformers? Not All Attention is Needed" (He et al., 2024)
## Implementation

### How It Works
Unlike KV cache quantization, this technique permanently removes the least important attention modules from the model architecture. The result is a smaller model (fewer parameters, less VRAM, faster inference) that can be saved and reloaded without any custom inference engine.
The importance of each attention sublayer is measured via cosine similarity between its input and output, including the residual connection. A layer whose output is nearly identical to its input is doing very little work and is a strong candidate for removal.
The metric, as defined in the paper, is:
S = 1 - CosineSim(X_A, Y_A)
Where X_A is the hidden state arriving at the attention sublayer (captured before input_layernorm) and Y_A is the hidden state after the attention computation and its residual (X_A + Attention(LayerNorm(X_A))). Including the residual in the measurement prevents underestimating layers that rely on it to preserve information flow.
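As a minimal, framework-free sketch (the function name is an assumption, not the notebook's), the score for one layer can be computed from flattened hidden-state vectors:

```python
import math

def attention_importance(x_a, y_a):
    """S = 1 - CosineSim(X_A, Y_A).

    A score near 0 means the attention sublayer (plus its residual) barely
    changed its input, making it a strong candidate for removal.
    """
    dot = sum(a * b for a, b in zip(x_a, y_a))
    norm_x = math.sqrt(sum(a * a for a in x_a))
    norm_y = math.sqrt(sum(b * b for b in y_a))
    return 1.0 - dot / (norm_x * norm_y)

attention_importance([1.0, 2.0], [1.0, 2.0])   # identical states -> score near 0
attention_importance([1.0, 0.0], [0.0, 1.0])   # orthogonal states -> score near 1
```

In practice X_A and Y_A are full hidden-state tensors averaged over tokens and calibration samples, but the per-vector metric is exactly this.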
### Calibration Data
Importance scores are computed over 400 samples from Cosmopedia, weighted to cover the same range of tasks as the evaluation benchmarks:
| Subset | Weight | Samples | Targets |
|---|---|---|---|
| stories | 0.300 | 120 | Contextual reasoning (hellaswag, winogrande) |
| web_samples_v2 | 0.200 | 80 | Contextual reasoning |
| web_samples_v1 | 0.150 | 60 | Contextual reasoning |
| wikihow | 0.150 | 60 | Procedural understanding (piqa) |
| openstax | 0.125 | 50 | Academic reasoning (arc_easy) |
| stanford | 0.075 | 30 | Academic reasoning |
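The weights map onto the 400-sample budget directly; a small sketch (variable names hypothetical) reproducing the sample counts in the table above:

```python
# Subset weights from the calibration table
CALIBRATION_WEIGHTS = {
    "stories": 0.300,
    "web_samples_v2": 0.200,
    "web_samples_v1": 0.150,
    "wikihow": 0.150,
    "openstax": 0.125,
    "stanford": 0.075,
}
TOTAL_SAMPLES = 400

# Per-subset counts: weight * total, rounded to whole samples
counts = {name: round(w * TOTAL_SAMPLES) for name, w in CALIBRATION_WEIGHTS.items()}
# counts["stories"] == 120, and the counts sum back to 400
```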
### Measuring Layer Importance
Hooks are anchored at two points inside each LlamaDecoderLayer:
```python
def setup_attention_hooks(model):
    hooks = []
    for i, layer in enumerate(model.model.layers):
        # X_A: hidden state before the attention transformation
        hooks.append(layer.input_layernorm.register_forward_pre_hook(make_input_hook(i)))
        # Y_A: hidden state after attention + residual
        hooks.append(layer.post_attention_layernorm.register_forward_pre_hook(make_output_hook(i)))
    return hooks, ...
```
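The `make_input_hook` / `make_output_hook` factories are not shown in this excerpt; one plausible sketch (the shared `captured` store is an assumption) records the tensor entering each layernorm, which is exactly X_A and Y_A:

```python
# Shared store for captured hidden states, keyed by layer index
captured = {"inputs": {}, "outputs": {}}

def make_input_hook(layer_idx):
    # forward_pre_hook signature: hook(module, args)
    def hook(module, args):
        # args[0] is the hidden state arriving at input_layernorm (X_A)
        captured["inputs"][layer_idx] = args[0].detach()
    return hook

def make_output_hook(layer_idx):
    def hook(module, args):
        # args[0] is the hidden state arriving at post_attention_layernorm,
        # i.e. X_A + Attention(LayerNorm(X_A)) -- this is Y_A
        captured["outputs"][layer_idx] = args[0].detach()
    return hook
```

Anchoring the Y_A capture at `post_attention_layernorm`'s pre-hook is what guarantees the residual is included in the measurement.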
The resulting importance scores for Llama-3.2-3B, sorted ascending (lowest = most redundant):

```
Layer   Score
-----   --------
21      0.006152  ← dropped
18      0.006841  ← dropped
22      0.007109  ← dropped
2       0.007285
...
1       0.110186
0       0.278848  ← most important, never touched
```
Layers 0 and 1 are clearly the most active and are never candidates for removal. The intermediate region (roughly layers 2–25) clusters near zero. Layers 26 and 27 recover, consistent with their role in consolidating representations before the language model head.
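Selecting the layers to drop is then just a matter of sorting scores while skipping protected layers; a hypothetical helper (the notebook's actual selection code is not shown here):

```python
def select_layers_to_drop(scores, k=3, protected=(0, 1)):
    """Pick the k attention sublayers with the lowest importance scores.

    scores: dict {layer_idx: importance score}
    protected: layer indices that are never candidates for removal
    """
    candidates = [(score, idx) for idx, score in scores.items() if idx not in protected]
    return sorted(idx for _, idx in sorted(candidates)[:k])

# Scores from the table above (abbreviated)
scores = {21: 0.006152, 18: 0.006841, 22: 0.007109, 2: 0.007285,
          1: 0.110186, 0: 0.278848}
select_layers_to_drop(scores)  # -> [18, 21, 22]
```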
### Physical Deletion
The 3 selected layers (18, 21, 22) are physically removed. For each layer, both self_attn and input_layernorm are deleted and the layer's forward() is patched to route hidden states directly to the MLP block:
```python
def drop_attention_layer(model, layer_idx):
    layer = model.model.layers[layer_idx]

    # Release attention parameters from memory
    delattr(layer, "self_attn")
    delattr(layer, "input_layernorm")

    def forward_no_attn(self, hidden_states, ...) -> Tuple[torch.Tensor]:
        # Attention block is skipped: hidden states flow through unchanged
        # MLP block: normalize, transform, add residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

    PrunedLayer = type("PrunedDecoderLayer", (type(layer),), {"forward": forward_no_attn})
    layer.__class__ = PrunedLayer
```
The dropped layer indices are stored in config.json so the model can be reloaded correctly:
"dropped_attn_layers": [18, 21, 22]
## Loading the Model
The model uses a custom class (PrunedLlamaForCausalLM) that re-applies the attention bypass automatically at load time. A single trust_remote_code=True flag is all that's needed:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "oopere/llama-3.2-3b-attn-drop-3",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("oopere/llama-3.2-3b-attn-drop-3")
```
Note: This model was developed and tested on a Google Colab T4 GPU (free tier). Use `torch.float16` on T4; switch to `torch.bfloat16` on Ampere-class GPUs or newer (L4, A100, etc.).
## Benchmarks
The following table compares the pruned model against the baseline (meta-llama/Llama-3.2-3B) on general knowledge benchmarks (0-shot):
| Benchmark | Metric | Baseline | Pruned | Δ (relative) |
|---|---|---|---|---|
| ARC Easy | acc_norm | 0.7180 | 0.7104 | -1.06% |
| HellaSwag | acc_norm | 0.7405 | 0.7274 | -1.77% |
| LAMBADA OpenAI | accuracy | 0.6969 | 0.6189 | -11.19% |
| PIQA | acc_norm | 0.7813 | 0.7693 | -1.54% |
| WinoGrande | accuracy | 0.6961 | 0.6843 | -1.70% |
### What the numbers mean
Removing 3 attention modules eliminates approximately 75M parameters (~150 MB in FP16). Because 3 of the model's 28 attention layers no longer write keys and values to the KV cache, it also reduces dynamic memory growth per generated token by ~10.7% (3/28).
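The ~75M figure is consistent with Llama-3.2-3B's published attention dimensions (hidden size 3072, 24 query heads, 8 key/value heads, head dimension 128); a quick sanity check:

```python
hidden, n_heads, n_kv_heads, head_dim = 3072, 24, 8, 128

q_proj = hidden * n_heads * head_dim      # 3072 x 3072
k_proj = hidden * n_kv_heads * head_dim   # 3072 x 1024 (grouped-query attention)
v_proj = hidden * n_kv_heads * head_dim
o_proj = n_heads * head_dim * hidden

per_layer = q_proj + k_proj + v_proj + o_proj   # ~25.2M parameters per attention module
dropped_params = 3 * per_layer                  # ~75.5M parameters for 3 layers
dropped_bytes_fp16 = dropped_params * 2         # ~151 MB in FP16
```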
Four of the five benchmarks show contained degradation of 1–2%. ARC-Easy, HellaSwag, PIQA, and WinoGrande are all multiple-choice or binary-selection tasks: the model receives a context and must pick the correct continuation from a closed set of alternatives. This type of task primarily activates factual reasoning and local semantic coherence. A 1–2% drop is consistent with having removed layers that the importance scores already identified as redundant: the model is still reasoning.
LAMBADA is a different story. This benchmark requires predicting the last word of a paragraph from the full preceding context, a task that demands generative coherence across several sentences. The 11.2% relative drop (7.8 percentage points) contrasts sharply with the behavior on all other evaluations.
### The repetition problem
The difference becomes immediately visible in free-text generation:
Baseline:

```
France and the largest city in the country. It is located in the north-central part of the
country, on the banks of the Seine River. The city is known for its rich history, culture,
and architecture.
```

Pruned model:

```
France. It is the largest city in France and the second largest city in Europe. The city is
located on the River Seine in the north of France. Paris is the most visited city in the
world. It is the most visited city in the world. It is the most visited city in the world.
```
The pruned model falls into a repetition loop it cannot exit. The same pattern was reproduced with Llama-3.1-8B.
The reason is structural: attention modules do not only transform token information; they also participate in suppressing the probability of recently generated tokens. When several are removed, that suppression weakens and the model tends to reinforce the most probable next token instead of advancing. LAMBADA captures exactly this limitation, which explains why its drop is roughly an order of magnitude larger than on the other benchmarks.
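The failure mode can be checked quantitatively by testing whether a generated token sequence ends in an exactly repeating block; a small, hypothetical detector (not part of the book's notebook):

```python
def has_repetition_loop(token_ids, min_period=3, min_repeats=3):
    """Return True if the sequence ends in at least `min_repeats` consecutive
    copies of the same block of at least `min_period` tokens."""
    n = len(token_ids)
    for period in range(min_period, n // min_repeats + 1):
        block = token_ids[n - period:]
        if all(token_ids[n - (r + 1) * period: n - r * period] == block
               for r in range(1, min_repeats)):
            return True
    return False

has_repetition_loop([5, 9, 2] * 4)       # -> True (looping tail)
has_repetition_loop(list(range(20)))     # -> False (no repetition)
```

Running a detector like this over sampled generations is a cheap way to quantify the trade-off before committing to a given number of dropped layers.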
## When to use this model

This technique is viable when outputs are short and structured: question answering, classification, information extraction, or any task where responses are a few tokens long. It is problematic for open-ended generation, where repetition loops degrade output quality significantly.
⚠️ This is not a neutral optimization. It carries a meaningful trade-off that is worth understanding before applying it in practice.
## Intended Use
This model is intended as a learning artifact for readers of Rearchitecting LLMs (Chapter 8 Hands-On Lab). It demonstrates that a non-trivial fraction of attention layers in a modern LLM can be physically removed with surprisingly small benchmark loss, and that the real cost only becomes visible in free-text generation. No custom inference stack is required to load or run it.
It is not intended for production use.