Скарты

#27
by Mrdips - opened
Files changed (2) hide show
  1. README.md +0 -2
  2. inference/model.py +16 -27
README.md CHANGED
@@ -79,9 +79,7 @@ This experimental release represents our ongoing research into more efficient tr
79
  | SWE-bench Multilingual | 57.8 | 57.9 |
80
  | Terminal-bench | 36.7 | 37.7 |
81
 
82
- ## Update
83
 
84
- - 2025.11.17: **We have identified that previous versions of the inference demo code contained an implementation discrepancy in Rotary Position Embedding (RoPE) within the indexer module, potentially leading to degraded model performance.** Specifically, the input tensor to RoPE in the indexer module requires a non-interleaved layout, whereas RoPE in the MLA module expects an interleaved layout. This issue has now been resolved. Please refer to the updated version of the inference demo code and take note of this implementation detail.
85
 
86
  ## How to Run Locally
87
 
 
79
  | SWE-bench Multilingual | 57.8 | 57.9 |
80
  | Terminal-bench | 36.7 | 37.7 |
81
 
 
82
 
 
83
 
84
  ## How to Run Locally
85
 
inference/model.py CHANGED
@@ -2,6 +2,7 @@ import math
2
  from dataclasses import dataclass
3
  from typing import Tuple, Optional, Literal
4
 
 
5
  import torch
6
  from torch import nn
7
  import torch.nn.functional as F
@@ -402,7 +403,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
402
  return freqs_cis
403
 
404
 
405
- def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
406
  """
407
  Applies rotary positional embeddings to the input tensor.
408
 
@@ -414,14 +415,9 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool
414
  torch.Tensor: Tensor with rotary embeddings applied.
415
  """
416
  dtype = x.dtype
417
- shape = x.shape
418
- if not interleaved:
419
- x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
420
- x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
421
  freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
422
  y = torch.view_as_real(x * freqs_cis).flatten(3)
423
- if not interleaved:
424
- y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
425
  return y.to(dtype)
426
 
427
 
@@ -445,8 +441,7 @@ class Indexer(torch.nn.Module):
445
  self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
446
  self.wk = Linear(self.dim, self.head_dim)
447
  self.k_norm = LayerNorm(self.head_dim)
448
- # weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
449
- self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
450
  self.softmax_scale = self.head_dim ** -0.5
451
  self.scale_fmt = args.scale_fmt
452
 
@@ -458,16 +453,14 @@ class Indexer(torch.nn.Module):
458
  bsz, seqlen, _ = x.size()
459
  end_pos = start_pos + seqlen
460
  q = self.wq_b(qr)
461
- q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
462
  q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
463
- # rope in indexer is not interleaved
464
- q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
465
  q = torch.cat([q_pe, q_nope], dim=-1)
466
  k = self.wk(x)
467
  k = self.k_norm(k)
468
  k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
469
- # rope in indexer is not interleaved
470
- k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
471
  k = torch.cat([k_pe, k_nope], dim=-1)
472
  q = rotate_activation(q)
473
  k = rotate_activation(k)
@@ -475,7 +468,7 @@ class Indexer(torch.nn.Module):
475
  k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
476
  self.k_cache[:bsz, start_pos:end_pos] = k_fp8
477
  self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
478
- weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
479
  weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
480
  index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
481
  if mask is not None:
@@ -531,7 +524,6 @@ class MLA(nn.Module):
531
  self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
532
  self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
533
  self.softmax_scale = self.qk_head_dim ** -0.5
534
- self.scale_fmt = args.scale_fmt
535
  if args.max_seq_len > args.original_seq_len:
536
  mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
537
  self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -566,9 +558,6 @@ class MLA(nn.Module):
566
  kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
567
  kv = self.kv_norm(kv)
568
  k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
569
- # we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
570
- kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
571
- kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
572
  self.kv_cache[:bsz, start_pos:end_pos] = kv
573
  self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
574
  if mask is not None: # MHA prefill
@@ -577,7 +566,7 @@ class MLA(nn.Module):
577
  kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
578
  k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
579
  k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
580
- scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
581
 
582
  # indexer
583
  topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
@@ -585,24 +574,24 @@ class MLA(nn.Module):
585
  index_mask += mask
586
  scores += index_mask.unsqueeze(2)
587
 
588
- scores = scores.softmax(dim=-1)
589
- x = torch.einsum("bsht,bthd->bshd", scores, v)
590
- else: # MQA decode
591
  if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
592
  self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
593
  wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
594
  wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
595
  q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
596
- scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
597
- torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
598
 
599
  # indexer
600
  topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
601
  index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
602
  scores += index_mask.unsqueeze(2)
603
 
604
- scores = scores.softmax(dim=-1)
605
- x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
606
  x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
607
  x = self.wo(x.flatten(2))
608
  return x
 
2
  from dataclasses import dataclass
3
  from typing import Tuple, Optional, Literal
4
 
5
+ from einops import rearrange
6
  import torch
7
  from torch import nn
8
  import torch.nn.functional as F
 
403
  return freqs_cis
404
 
405
 
406
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
407
  """
408
  Applies rotary positional embeddings to the input tensor.
409
 
 
415
  torch.Tensor: Tensor with rotary embeddings applied.
416
  """
417
  dtype = x.dtype
418
+ x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
 
 
 
419
  freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
420
  y = torch.view_as_real(x * freqs_cis).flatten(3)
 
 
421
  return y.to(dtype)
422
 
423
 
 
441
  self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
442
  self.wk = Linear(self.dim, self.head_dim)
443
  self.k_norm = LayerNorm(self.head_dim)
444
+ self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
 
445
  self.softmax_scale = self.head_dim ** -0.5
446
  self.scale_fmt = args.scale_fmt
447
 
 
453
  bsz, seqlen, _ = x.size()
454
  end_pos = start_pos + seqlen
455
  q = self.wq_b(qr)
456
+ q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
457
  q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
458
+ q_pe = apply_rotary_emb(q_pe, freqs_cis)
 
459
  q = torch.cat([q_pe, q_nope], dim=-1)
460
  k = self.wk(x)
461
  k = self.k_norm(k)
462
  k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
463
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
 
464
  k = torch.cat([k_pe, k_nope], dim=-1)
465
  q = rotate_activation(q)
466
  k = rotate_activation(k)
 
468
  k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
469
  self.k_cache[:bsz, start_pos:end_pos] = k_fp8
470
  self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
471
+ weights = self.weights_proj(x) * self.n_heads ** -0.5
472
  weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
473
  index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
474
  if mask is not None:
 
524
  self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
525
  self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
526
  self.softmax_scale = self.qk_head_dim ** -0.5
 
527
  if args.max_seq_len > args.original_seq_len:
528
  mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
529
  self.softmax_scale = self.softmax_scale * mscale * mscale
 
558
  kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
559
  kv = self.kv_norm(kv)
560
  k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
 
 
 
561
  self.kv_cache[:bsz, start_pos:end_pos] = kv
562
  self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
563
  if mask is not None: # MHA prefill
 
566
  kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
567
  k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
568
  k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
569
+ scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
570
 
571
  # indexer
572
  topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
 
574
  index_mask += mask
575
  scores += index_mask.unsqueeze(2)
576
 
577
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
578
+ x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
579
+ else: # MHA decode
580
  if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
581
  self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
582
  wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
583
  wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
584
  q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
585
+ scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
586
+ torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
587
 
588
  # indexer
589
  topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
590
  index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
591
  scores += index_mask.unsqueeze(2)
592
 
593
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
594
+ x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
595
  x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
596
  x = self.wo(x.flatten(2))
597
  return x