---
license: apache-2.0
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
metrics:
- accuracy
- exact_match
- f1
- recall
base_model:
- Qwen/Qwen3-4B-Instruct-2507
pipeline_tag: text-ranking
tags:
- Rerank
- Memory
---

# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.

## Model Description

Building on existing analyses of retrieval heads in large language models, QRRanker trains models to estimate passage-query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified by computing QR scores on seed data and are particularly effective at capturing query-document relevance signals.

Our approach provides a **listwise solution** that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.

### Key Features

- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Focuses on top-performing QR attention heads
- **Layer Truncation**: Only the first 25 of 36 layers are retained; all QR heads fall within layers 17–24, so deeper layers are unnecessary
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues

## Architecture

This model is a **layer-truncated** version of Qwen3-4B-Instruct-2507. The original model has 36 transformer layers, but only the first **25 layers** are retained. The top-performing QR heads (layers 17–24) all fall within this range; deeper layers contribute no useful QR signal but consume extra computation and memory.
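
The truncation argument can be checked mechanically. The sketch below is illustrative only and is not taken from `modeling_qwen3_qr.py`: it parses the default head list from the "Default Top-16 QR Heads" section, confirms that every head sits in a retained layer, and shows one plausible reading of the remapping behind `qr_head_list_mapped`. The helper names `parse_heads` and `remap_heads` are hypothetical, and the rule "subtract `qr_start_layer` from the layer index" is an assumption.

```python
# Default top-16 QR heads as "layer-head" pairs, copied from this model card.
QR_HEAD_LIST = (
    "20-15,21-11,17-27,23-10,22-4,21-10,21-8,21-18,"
    "18-15,18-19,17-25,17-17,24-13,17-4,19-12,21-31"
)
QR_START_LAYER = 17        # first layer containing QR heads
QR_END_LAYER = 25          # layers 0-24 are retained, layers 25-35 are removed
NUM_ATTENTION_HEADS = 32   # attention heads per layer


def parse_heads(head_list):
    """Parse "layer-head,layer-head,..." into a list of (layer, head) int pairs."""
    pairs = []
    for item in head_list.split(","):
        layer, head = item.strip().split("-")
        pairs.append((int(layer), int(head)))
    return pairs


def remap_heads(heads, start_layer):
    """Shift layer indices so that start_layer becomes 0.

    Assumption: this is only one plausible interpretation of
    "remapped relative to qr_start_layer".
    """
    return [(layer - start_layer, head) for layer, head in heads]


heads = parse_heads(QR_HEAD_LIST)

# Truncation is safe only if every QR head lives in a retained layer.
assert all(QR_START_LAYER <= layer < QR_END_LAYER for layer, _ in heads)
assert all(0 <= head < NUM_ATTENTION_HEADS for _, head in heads)

mapped = remap_heads(heads, QR_START_LAYER)
print(len(heads))    # 16
print(mapped[:3])    # [(3, 15), (4, 11), (0, 27)]
```

If a different head list is supplied, the two assertions act as a guard: a head in layer 25 or deeper would be unreachable in the truncated model.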

Key design choices in `modeling_qwen3_qr.py`:

- **`Qwen3ConfigGating`**: Extends `Qwen3Config` with `qr_start_layer`, `qr_end_layer`, `qr_head_list`, and `qr_head_list_mapped` (head indices remapped relative to `qr_start_layer`)
- **Layer construction**: Only instantiates `qr_end_layer` (25) layers instead of all `num_hidden_layers` (36)
- **No final norm**: Skips `self.norm(hidden_states)` since only the intermediate KV/query caches are needed, not the final hidden state
- **`DynamicCacheWithQuery`**: Custom KV cache that additionally stores query states at specified token positions during the forward pass

### Default Top-16 QR Heads

```
Layer-Head: 20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
            18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```

All selected heads fall within layers 17–24, which is why truncation to 25 layers is safe.

### Model Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| `qr_start_layer` | 17 | First layer containing QR heads |
| `qr_end_layer` | 25 | Layers 0–24 are retained; layers 25–35 are removed |
| `qr_head_list` | 16 (layer, head) pairs | Top QR heads using original layer indices |
| `qr_head_list_mapped` | 16 (layer, head) pairs | QR heads with layer indices remapped relative to `qr_start_layer` |
| `num_hidden_layers` | 36 | Original full model depth (config only, not instantiated) |
| `num_attention_heads` | 32 | Attention heads per layer |
| `num_key_value_heads` | 8 | GQA key-value heads per layer |

## Quick Start

### Loading the Model

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model: trust_remote_code loads the layer-truncated Qwen3Model
# and Qwen3ConfigGating automatically via auto_map in config.json
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker")
```

### QR Score Computation

After a forward pass, QR scores are computed from the cached query and key states:

```python
import math

import torch


def repeat_kv(hidden_states, n_rep):
    """Expand KV heads to match query heads (GQA)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def get_attn_weights(key_states, query_states):
    """Compute softmax attention weights with causal mask."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_kv_heads = key_states.size(1)
    key_states = repeat_kv(key_states, num_heads // num_kv_heads)

    scale = 1.0 / math.sqrt(head_dim)
    attn_weights = torch.matmul(query_states * scale, key_states.transpose(2, 3))

    # Causal mask: the q_len query positions are aligned to the end of the key sequence
    seq_len = attn_weights.size(-1)
    causal_mask = torch.ones(num_heads, q_len, seq_len, device=attn_weights.device)
    causal_mask = torch.triu(causal_mask.transpose(-1, -2), diagonal=-(seq_len - q_len)).transpose(-1, -2)
    attn_weights += ((1 - causal_mask) * torch.finfo(attn_weights.dtype).min).unsqueeze(0)

    # Softmax computed via log-sum-exp
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    return torch.exp(attn_weights - attn_lses)


def compute_qr_scores(query_cache, key_cache, qr_head_list, chunk_ranges, query_upper_bound):
    """
    Compute QRRanker relevance scores for document chunks.

    Args:
        query_cache: List[Tensor], query states per layer from DynamicCacheWithQuery
        key_cache: List[Tensor], key states per layer
        qr_head_list: str, e.g. "20-15,21-11,17-27,..."
        chunk_ranges: List[[start, end]], token ranges for each chunk
        query_upper_bound: int, upper bound of query token positions

    Returns:
        scores: Tensor of shape [num_chunks]
    """
    all_head_scores = []
    for key_state, query_state in zip(key_cache, query_cache):
        attn_weights = get_attn_weights(key_state[:, :, :query_upper_bound, :], query_state)
        attn_weights = attn_weights.mean(dim=-2)  # average over query positions
        # Sum the attention mass that falls inside each chunk's token range
        chunk_scores = torch.stack(
            [attn_weights[:, :, s:e].sum(dim=-1) for s, e in chunk_ranges], dim=2
        )
        all_head_scores.append(chunk_scores)

    # [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set, device=all_head_scores.device)
        all_head_scores = all_head_scores[:, indices[:, 0], indices[:, 1], :]

    return all_head_scores.sum(dim=1).squeeze(0)
```

### Complete Inference Pipeline

```python
import torch

from custom_cache_new import DynamicCacheWithQuery

# compute_qr_scores is the function defined under "QR Score Computation".


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank candidate paragraphs by QRRanker relevance scores.

    Args:
        model: QRRanker model (loaded with trust_remote_code=True)
        tokenizer: Corresponding tokenizer
        question: Query string
        paragraphs: List of dicts with 'idx', 'title', 'paragraph_text'
        qr_head_list: str, e.g. "20-15,21-11,17-27,..."
        device: torch device

    Returns:
        ranked_ids: Paragraph indices sorted by descending relevance
        ranked_scores: Corresponding scores
    """
    # Build input: [chunks] + [query]
    prompt_prefix = '<|im_start|>user\nHere are some retrieved chunks:\n\n'
    chunk_part = prompt_prefix
    chunk_ranges = []  # character ranges of each chunk's text within chunk_part

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize
    inputs = tokenizer(full_seq, max_length=262144, truncation=True,
                       return_tensors='pt', return_offsets_mapping=True,
                       add_special_tokens=False)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0].tolist()

    # Character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    token_chunk_ranges = [
        [char_to_token.get(s, 0), char_to_token.get(e - 1, 0) + 1]
        for s, e in chunk_ranges
    ]

    # The question is the final substring of full_seq. Locating it from the end
    # avoids matching an earlier occurrence of the same text inside a chunk,
    # which full_seq.index(question) could return.
    query_start = len(full_seq) - len(question)
    query_positions = list(range(
        char_to_token[query_start],
        char_to_token[query_start + len(question) - 1] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass
    with torch.no_grad():
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)
        output = model(input_ids, attention_mask, past_key_values=past_kv)

    scores = compute_qr_scores(
        output.past_key_values.query_cache,
        output.past_key_values.key_cache,
        qr_head_list, token_chunk_ranges, query_upper_bound
    )

    sorted_idx = torch.argsort(scores, descending=True).cpu().tolist()
    return [paragraphs[i]['idx'] for i in sorted_idx], [float(scores[i]) for i in sorted_idx]
```

## Input Data Format

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    }
  ],
  "summary": "Optional summary text..."
}
```

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether it's a supporting paragraph (for evaluation) |
| `summary` | string | No | Optional summary information |

## Environment

| Package | Version |
|---------|---------|
| Python | 3.10 |
| torch | 2.7.1 |
| transformers | 4.53.0 |
| flash-attn | (required for `flash_attention_2`) |
| safetensors | 0.5.3 |
| tokenizers | 0.21.2 |

```bash
pip install torch==2.7.1 transformers==4.53.0 safetensors
pip install flash-attn --no-build-isolation
```

## Citation

```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
      title={Query-focused and Memory-aware Reranker for Long Context Processing},
      author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
      year={2026},
      eprint={2602.12192},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2602.12192},
}
```

## License

This project is licensed under the Apache 2.0 License.