llama-3.2-3b-attn-drop-3

Model Description

This model is a surgically optimized version of meta-llama/Llama-3.2-3B, created as part of Chapter 8 in the book "Rearchitecting LLMs".



Implementation

How It Works

Unlike KV cache quantization, this technique permanently removes the least important attention modules from the model architecture. The result is a smaller model (fewer parameters, less VRAM, faster inference) that can be saved and reloaded without any custom inference engine.

The importance of each attention sublayer is measured via cosine similarity between its input and output, including the residual connection. A layer whose output is nearly identical to its input is doing very little work and is a strong candidate for removal.

The metric, as defined in the paper, is:

S = 1 - CosineSim(X_A, Y_A)

Where X_A is the hidden state arriving at the attention sublayer (captured before input_layernorm) and Y_A is the hidden state after the attention computation and its residual (X_A + Attention(LayerNorm(X_A))). Including the residual in the measurement prevents underestimating layers that rely on it to preserve information flow.
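
A minimal sketch of the score for a single sublayer. It assumes the hidden states are tensors of shape (batch, seq_len, hidden), and it averages the per-token scores. The card does not say how scores are aggregated across tokens and samples, so the averaging is an assumption. The function name attention_importance is hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_importance(x_a: torch.Tensor, y_a: torch.Tensor) -> float:
    """S = 1 - CosineSim(X_A, Y_A), averaged over every token position.

    x_a: hidden state entering the attention sublayer, shape (batch, seq, hidden)
    y_a: X_A + Attention(LayerNorm(X_A)), same shape
    Averaging over tokens is an assumption of this sketch.
    """
    # Cosine similarity along the hidden dimension: one value per token
    cos = F.cosine_similarity(x_a.float(), y_a.float(), dim=-1)
    return float((1.0 - cos).mean())


# Example: compare a sublayer that barely changes its input with one that rewrites it
torch.manual_seed(0)
x = torch.randn(1, 8, 64)
y_small = x + 0.01 * torch.randn_like(x)  # attention output is tiny relative to x
y_large = x + 2.0 * torch.randn_like(x)   # attention output dominates x
print(attention_importance(x, y_small))   # close to 0: candidate for removal
print(attention_importance(x, y_large))   # much larger: layer is doing real work
```

Because Y_A includes the residual, a sublayer only scores near zero when its own contribution is small compared to the information it passes through.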

Calibration Data

Importance scores are computed over 400 samples from Cosmopedia, weighted to cover the same range of tasks as the evaluation benchmarks:

Subset          Weight  Samples  Activates
--------------  ------  -------  --------------------------------------------
stories          0.300      120  Contextual reasoning (hellaswag, winogrande)
web_samples_v2   0.200       80  Contextual reasoning
web_samples_v1   0.150       60  Contextual reasoning
wikihow          0.150       60  Procedural understanding (piqa)
openstax         0.125       50  Academic reasoning (arc_easy)
stanford         0.075       30  Academic reasoning
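
The sample counts follow directly from the weights (weight × 400). Below is a minimal sketch of that allocation, plus a loader for drawing the texts. The loader rests on several assumptions: the Hugging Face dataset id HuggingFaceTB/cosmopedia with one config per subset, a text column, and streaming access through the datasets library. The helper names and the take-first-N sampling strategy are hypothetical, not the book's code.

```python
import math
from typing import Dict, List

# Subset weights from the calibration table above
COSMOPEDIA_WEIGHTS: Dict[str, float] = {
    "stories": 0.300,
    "web_samples_v2": 0.200,
    "web_samples_v1": 0.150,
    "wikihow": 0.150,
    "openstax": 0.125,
    "stanford": 0.075,
}


def allocate_samples(weights: Dict[str, float], total: int) -> Dict[str, int]:
    """Turn subset weights into integer sample counts that sum to `total`."""
    if abs(sum(weights.values()) - 1.0) > 1e-6:
        raise ValueError("weights must sum to 1")
    exact = {name: w * total for name, w in weights.items()}
    # Small epsilon guards against float results such as 59.999999...
    counts = {name: math.floor(v + 1e-9) for name, v in exact.items()}
    # Largest-remainder method: leftover samples go to the biggest fractional parts
    leftover = total - sum(counts.values())
    by_remainder = sorted(exact, key=lambda n: exact[n] - counts[n], reverse=True)
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts


def load_calibration_texts(counts: Dict[str, int]) -> List[str]:
    """Hypothetical loader: take the first N streamed rows of each subset."""
    from datasets import load_dataset  # lazy import, only needed when loading

    texts: List[str] = []
    for subset, n in counts.items():
        # Assumed dataset id, config names and "text" column
        stream = load_dataset("HuggingFaceTB/cosmopedia", subset,
                              split="train", streaming=True)
        texts.extend(row["text"] for row in stream.take(n))
    return texts


print(allocate_samples(COSMOPEDIA_WEIGHTS, 400))
```

With these weights the split is exact (120, 80, 60, 60, 50, 30). The remainder step only matters for totals that do not divide evenly.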

Measuring Layer Importance

Hooks are anchored at two points inside each LlamaDecoderLayer:

def setup_attention_hooks(model):
    # make_input_hook(i) and make_output_hook(i) are not shown here; each returns
    # a forward pre-hook that captures the hidden state for layer i
    hooks = []
    for i, layer in enumerate(model.model.layers):
        # X_A: hidden state before attention transformation
        hooks.append(layer.input_layernorm.register_forward_pre_hook(make_input_hook(i)))
        # Y_A: hidden state after attention + residual
        hooks.append(layer.post_attention_layernorm.register_forward_pre_hook(make_output_hook(i)))
    return hooks, ...
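
The snippet above leaves the hook bodies and the scoring loop out. One way to fill them in is sketched below, under stated assumptions:

- The bodies of make_input_hook and make_output_hook are hypothetical reconstructions. Each one reads args[0], the hidden state passed positionally to the norm.
- Scores are averaged per token.
- Calibration batches are unpadded input tensors.

The sketch uses only standard PyTorch forward pre-hooks. It should therefore apply to any model that exposes model.model.layers with input_layernorm and post_attention_layernorm, as LlamaForCausalLM does.

```python
from collections import defaultdict
from typing import Dict, Iterable

import torch
import torch.nn.functional as F


def measure_attention_importance(model: torch.nn.Module,
                                 batches: Iterable[torch.Tensor]) -> Dict[int, float]:
    """Return {layer_idx: mean(1 - cos(X_A, Y_A))} over all calibration tokens."""
    captured_in: Dict[int, torch.Tensor] = {}
    score_sum: Dict[int, float] = defaultdict(float)
    token_count: Dict[int, int] = defaultdict(int)

    def make_input_hook(i: int):
        # Hypothetical body: stash X_A, the input of input_layernorm
        def hook(module, args):
            captured_in[i] = args[0].detach()
        return hook

    def make_output_hook(i: int):
        # Hypothetical body: the input of post_attention_layernorm is
        # X_A + Attention(LayerNorm(X_A)), i.e. Y_A
        def hook(module, args):
            y_a = args[0].detach()
            cos = F.cosine_similarity(captured_in[i].float(), y_a.float(), dim=-1)
            score_sum[i] += float((1.0 - cos).sum())  # padding would be counted too
            token_count[i] += cos.numel()
        return hook

    hooks = []
    for i, layer in enumerate(model.model.layers):
        hooks.append(layer.input_layernorm.register_forward_pre_hook(make_input_hook(i)))
        hooks.append(layer.post_attention_layernorm.register_forward_pre_hook(make_output_hook(i)))

    model.eval()
    try:
        with torch.no_grad():
            for batch in batches:
                model(batch)  # for a Hugging Face model, batch is input_ids
    finally:
        for h in hooks:
            h.remove()  # always detach hooks, even if a forward pass fails

    return {i: score_sum[i] / token_count[i] for i in token_count}
```

Removing the hooks in a finally block keeps the model clean even if a calibration batch fails halfway through.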

The resulting importance scores for Llama-3.2-3B (ascending; lowest = most redundant):

Layer   Score
------  ----------
  21    0.006152   ← dropped
  18    0.006841   ← dropped
  22    0.007109   ← dropped
   2    0.007285
  ...
   1    0.110186
   0    0.278848   ← most important, never touched
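
Once the scores exist, selecting the layers to drop reduces to a sort. A minimal sketch follows. It assumes the scores come as a {layer_idx: score} dict and that layers 0 and 1 are excluded explicitly. The card says they are never candidates, but hard-coding that protected set is this sketch's assumption, and the function name is hypothetical.

```python
from typing import Dict, Iterable, List


def select_layers_to_drop(scores: Dict[int, float], n_drop: int,
                          protected: Iterable[int] = (0, 1)) -> List[int]:
    """Pick the n_drop lowest-scoring (most redundant) attention sublayers."""
    protected = set(protected)
    candidates = [i for i in scores if i not in protected]
    # Lowest score = smallest change between X_A and Y_A = most redundant
    lowest = sorted(candidates, key=lambda i: scores[i])[:n_drop]
    return sorted(lowest)  # ascending index order, as stored in config.json


# Scores copied from the table above (its middle rows are elided there too)
scores = {21: 0.006152, 18: 0.006841, 22: 0.007109, 2: 0.007285,
          1: 0.110186, 0: 0.278848}
print(select_layers_to_drop(scores, n_drop=3))  # [18, 21, 22]
```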

Layers 0 and 1 are clearly the most active and are never candidates for removal. The intermediate region (roughly layers 2–25) clusters near zero. Layers 26 and 27 recover, consistent with their role in consolidating representations before the language model head.

Physical Deletion

The attention sublayers of the 3 selected layers (18, 21, 22) are physically removed, while the MLP half of each layer stays in place. For each of these layers, both self_attn and input_layernorm are deleted, and the layer's forward() is patched to route hidden states directly to the MLP block:

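import torch
from torch import nn


def drop_attention_layer(model: nn.Module, layer_idx: int) -> None:
    layer = model.model.layers[layer_idx]

    # Remove the attention sublayer and its pre-norm; nn.Module.__delattr__
    # drops them from layer._modules so their weights can be freed
    delattr(layer, "self_attn")
    delattr(layer, "input_layernorm")

    def forward_no_attn(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Attention block is skipped: hidden states flow through unchanged.
        # Extra arguments (attention_mask, position_ids, cache, ...) are ignored.
        # MLP block: normalize, transform, add residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        # Recent transformers versions expect a bare tensor here; older versions,
        # whose LlamaDecoderLayer returns a tuple, need (hidden_states,) instead
        return hidden_states

    # Give this single layer instance a subclass whose forward skips attention
    PrunedLayer = type("PrunedDecoderLayer", (type(layer),), {"forward": forward_no_attn})
    layer.__class__ = PrunedLayer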

The dropped layer indices are stored in config.json so the model can be reloaded correctly:

"dropped_attn_layers": [18, 21, 22]

Loading the Model

The model uses a custom class (PrunedLlamaForCausalLM) that re-applies the attention bypass automatically at load time. A single trust_remote_code=True flag is all that's needed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "oopere/llama-3.2-3b-attn-drop-3",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("oopere/llama-3.2-3b-attn-drop-3")

Note: This model was developed and tested on a Google Colab T4 GPU (free tier). Use torch.float16 on T4; switch to torch.bfloat16 on Ampere-class GPUs or newer (L4, A100, etc.).


Benchmarks

The following table compares the pruned model against the baseline (meta-llama/Llama-3.2-3B) on general knowledge benchmarks (0-shot):

Benchmark       Metric    Baseline  Pruned  Δ (relative)
--------------  --------  --------  ------  ------------
ARC Easy        acc_norm    0.7180  0.7104        -1.06%
HellaSwag       acc_norm    0.7405  0.7274        -1.77%
LAMBADA OpenAI  accuracy    0.6969  0.6189       -11.19%
PIQA            acc_norm    0.7813  0.7693        -1.54%
WinoGrande      accuracy    0.6961  0.6843        -1.70%
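
The Δ column is the relative change with respect to the baseline, (pruned - baseline) / baseline. It is not an absolute difference in accuracy points. This short sketch recomputes both measures from the table values:

```python
# (baseline, pruned) scores from the benchmark table above
RESULTS = {
    "ARC Easy": (0.7180, 0.7104),
    "HellaSwag": (0.7405, 0.7274),
    "LAMBADA OpenAI": (0.6969, 0.6189),
    "PIQA": (0.7813, 0.7693),
    "WinoGrande": (0.6961, 0.6843),
}


def relative_change_pct(baseline: float, pruned: float) -> float:
    """Relative change in percent, as reported in the Δ column."""
    return (pruned - baseline) / baseline * 100


def absolute_change_pp(baseline: float, pruned: float) -> float:
    """Absolute change in accuracy (percentage points)."""
    return (pruned - baseline) * 100


for name, (base, pruned) in RESULTS.items():
    print(f"{name:15s} rel {relative_change_pct(base, pruned):+7.2f}%   "
          f"abs {absolute_change_pp(base, pruned):+6.2f} pp")
```

For LAMBADA, the -11.19% relative drop corresponds to 7.8 accuracy points (0.6969 to 0.6189).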

What the numbers mean

Removing 3 attention modules eliminates approximately 75M parameters (~150MB in FP16). It also shrinks the KV cache by 3/28 ≈ 10.7%. The dropped layers no longer store keys and values, so memory grows about 10.7% more slowly with each generated token.
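
These figures can be checked against the published Llama-3.2-3B configuration: hidden_size 3072, 24 query heads, 8 key/value heads (grouped-query attention), head_dim 128 and 28 layers. The back-of-the-envelope sketch below assumes bias-free q/k/v/o projections, as in Llama, and also counts the deleted input_layernorm weight.

```python
HIDDEN = 3072
N_HEADS = 24
N_KV_HEADS = 8
HEAD_DIM = HIDDEN // N_HEADS  # 128
N_LAYERS = 28
N_DROPPED = 3
BYTES_FP16 = 2


def attention_params_per_layer() -> int:
    q = HIDDEN * N_HEADS * HEAD_DIM     # q_proj
    k = HIDDEN * N_KV_HEADS * HEAD_DIM  # k_proj (grouped-query attention)
    v = HIDDEN * N_KV_HEADS * HEAD_DIM  # v_proj
    o = N_HEADS * HEAD_DIM * HIDDEN     # o_proj
    norm = HIDDEN                       # input_layernorm weight
    return q + k + v + o + norm


removed = N_DROPPED * attention_params_per_layer()
print(f"parameters removed: {removed:,}")                # 75,506,688
print(f"FP16 size: {removed * BYTES_FP16 / 1e6:.0f} MB")  # 151 MB

# KV cache per generated token: one K and one V vector per KV head, per attention layer
kv_per_layer = 2 * N_KV_HEADS * HEAD_DIM * BYTES_FP16    # 4,096 bytes
before = N_LAYERS * kv_per_layer                          # 114,688 bytes
after = (N_LAYERS - N_DROPPED) * kv_per_layer             # 102,400 bytes
print(f"KV cache reduction: {1 - after / before:.1%}")    # 10.7%
```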

Four of the five benchmarks show a contained relative degradation of 1–2%. ARC-Easy, HellaSwag, PIQA, and WinoGrande are all multiple-choice or binary-selection tasks: the model receives a context and must pick the correct continuation from a closed set of alternatives. This type of task primarily exercises factual reasoning and local semantic coherence. A 1–2% drop is consistent with having removed layers that the importance scores had already identified as redundant. On these tasks, the model's reasoning remains largely intact.

LAMBADA is a different story. This benchmark requires predicting the last word of a paragraph from the full preceding context, a task that demands generative coherence across several sentences. Its 11.2% relative drop (7.8 accuracy points) contrasts sharply with the behavior on all other evaluations.

The repetition problem

The difference becomes immediately visible in free-text generation:

Baseline:

France and the largest city in the country. It is located in the north-central part of the
country, on the banks of the Seine River. The city is known for its rich history, culture,
and architecture.

Pruned model:

France. It is the largest city in France and the second largest city in Europe. The city is
located on the River Seine in the north of France. Paris is the most visited city in the
world. It is the most visited city in the world. It is the most visited city in the world.

The pruned model falls into a repetition loop it cannot exit. The same pattern was reproduced with Llama-3.1-8B.
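
A loop like this is easy to flag automatically. The minimal sketch below counts verbatim repeated sentences in a generation. Splitting on ., ! and ? is a simplifying assumption, and the function name is hypothetical.

```python
import re
from collections import Counter
from typing import Dict


def repeated_sentences(text: str) -> Dict[str, int]:
    """Return sentences that occur more than once, with their counts."""
    # Collapse line breaks and extra spaces, then split after sentence-ending punctuation
    flat = " ".join(text.split())
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", flat) if s]
    counts = Counter(sentences)
    return {s: n for s, n in counts.items() if n > 1}


pruned_output = (
    "France. It is the largest city in France and the second largest city in Europe. "
    "The city is located on the River Seine in the north of France. Paris is the most "
    "visited city in the world. It is the most visited city in the world. It is the most "
    "visited city in the world. It is the most visited city in the world."
)
print(repeated_sentences(pruned_output))
# {'It is the most visited city in the world.': 3}
```

Applied to the baseline continuation, the same function returns an empty dict.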

The reason is structural. Attention modules do not only transform token information; they also help suppress the probability of recently generated tokens. When several are removed, that suppression weakens and the model tends to reinforce the most probable next token instead of advancing. LAMBADA captures exactly this limitation, which explains why its drop is roughly 6 to 10 times larger than on the other benchmarks.

When to use this model

This technique is viable when outputs are short and structured β€” question answering, classification, information extraction, or any task where responses are a few tokens long. It is problematic for open-ended generation, where repetition loops degrade output quality significantly.

⚠️ This is not a neutral optimization. It carries a meaningful trade-off that is worth understanding before applying it in practice.


Intended Use

This model is intended as a learning artifact for readers of Rearchitecting LLMs (Chapter 8 Hands-On Lab). It demonstrates that a non-trivial fraction of attention layers in a modern LLM can be physically removed with surprisingly small benchmark loss, and that the real cost only becomes visible in free-text generation. No custom inference stack is required to load or run it.

It is not intended for production use.
