基于 DSpark 的投机解码训练框架原理与实现(论文+代码对照)

结合 DSpark 论文与代码实现,全面剖析 DeepSpec 的工作原理与核心组件。

项目地址:https://github.com/deepseek-ai/DeepSpec DSpark 论文:https://github.com/deepseek-ai/DeepSpec/blob/main/DSpark_paper.pdf

DSpark 是 DeepSeek 提出的一套无损加速大模型推理的“看人下菜碟”机制。 传统加速手段(推测解码)通常是让小模型一次性盲目盲猜一大串后续 Token,再让大模型统一验证。但这存在两个痛点:小模型猜得越往后越不准(多模态冲突导致“后缀衰减”);高并发时,大模型花大力气去验证那些猜得不准的 Token,会严重压垮系统吞吐。

DSpark 的核心突破就在于两点:

  1. 猜得更准(半自回归): 它在原有的单次并行生成网络后,拼了一个极轻量的小尾巴(顺序头),在几乎不增加延迟的情况下,让后面的 Token 能根据前面猜出的 Token 进行自适应修正,大幅提升长序列的猜测准确度。
  2. 动态裁剪(置信度调度): 它能实时感知系统的硬件负载与并发压力。如果并发高、大模型很忙,或者发现后面小模型猜的置信度太低,它就会果断把不靠谱的后缀砍掉,只送靠谱的前缀给大模型验证。

通过这种“高质量猜测”与“负载感知动态裁剪”的结合,DSpark 在保障大模型输出质量完全无损的前提下,成功在 DeepSeek 真实的高并发生产环境中,让用户端生成速度暴增了 60% ~ 85%。

graph TD
    A[Target Model 验证结束] -->|输出 Anchor Token D| B(DSpark 级联生成)
    subgraph DSpark 草稿生成阶段
        B --> C[并行骨干网络 Parallel Backbone]
        C -->|输出 Base Logits & 隐状态| D[轻量化顺序头 Sequential Block]
        D -->|注入局部转移偏置| E[生成候选序列 E, F, G, H]
        E --> F[置信度头 Confidence Head]
        F -->|评估序列条件生存概率 C1-C4| G[连续前缀校准 STS]
    end
    G --> H[硬件感知前缀调度器 Hardware-Aware Prefix Scheduler]
    H -->|结合实时系统吞吐曲线 SPS| I{动态截断决策}
    I -->|保留高回报前缀 E, F, G| J[Target Model 并行验证]
    I -->|裁剪低置信度后缀 H| K[释放验证算力/Batch容量]

目录

  1. 项目概述
  2. 背景:投机解码(Speculative Decoding)
  3. DSpark 算法原理(论文核心)
  4. 核心组件与代码架构
  5. 三种算法对比:DSpark / DFlash / Eagle3
  6. 端到端数据流
  7. 训练流程深度剖析
  8. 推理与评估流程
  9. 关键设计决策
  10. 总结

1. 项目概述

DeepSpec 是一个用于训练与评估投机解码(Speculative Decoding)草稿模型的全栈代码库,由 DeepSeek-AI 团队开源。

核心问题

大型语言模型(LLM)推理时逐 token 生成,延迟与输出长度成正比。投机解码通过引入一个轻量草稿模型(Draft Model)并行提出多个候选 token,再由目标模型(Target Model,即完整大模型)单次前向传播批量验证,从而加速推理。

项目组成

DeepSpec/
├── 三种草稿模型算法
│   ├── DSpark   — 半自回归 + 置信度调度(最新、最强)
│   ├── DFlash   — 纯并行生成(DSpark 的简化版)
│   └── Eagle3   — 自回归 + TTT(Test-Time Training)
├── 完整训练管线(数据准备 → 训练 → 评估)
├── 支持的目标模型:Qwen3 (4B/8B/14B)、Gemma4 (12B)
└── 评估基准:gsm8k, math500, aime25, humaneval, mbpp, ...

2. 背景:投机解码

2.1 基本原理

sequenceDiagram
    participant T as Target Model (大模型)
    participant D as Draft Model (小模型)
    participant O as 输出

    Note over T,D: 每个解码周期"
    D->>T: 提出 γ 个候选 token (并行)
    T->>T: 单次前向传播验证全部候选
    T->>O: 接受最长一致前缀 + 1 个 bonus token
    Note over O: 分布完全保留(拒绝采样保证)

2.2 延迟公式

论文中的核心公式:

L=Tdraft+γτ\mathcal{L} = T_{\text{draft}} + \frac{\gamma}{\tau}
  • TdraftT_{\text{draft}}:草稿生成延迟
  • γ\gamma:提案 token 数
  • τ\tau:接受长度(expected accepted tokens per round)

目标:最小化 L\mathcal{L},即同时降低 TdraftT_{\text{draft}} 和增大 τ\tau

2.3 两类草稿模型的设计权衡

graph LR
    subgraph 自回归["Autoregressive Drafters (Eagle3)"]
        A1["T_draft ∝ γ<br/>延迟随块大小线性增长"]
        A2["τ 高<br/>逐 token 条件依赖建模好"]
        A3["γ 与模型深度受限<br/>被迫使用小 γ 和浅层架构"]
    end

    subgraph 并行["Parallel Drafters (DFlash)"]
        P1["T_draft ≈ 常数<br/>单次前向传播"]
        P2["τ 衰减快<br/>无块内 token 依赖"]
        P3["可用大 γ + 深层架构"]
    end

    subgraph DSpark["DSpark (本文方法)"]
        D1["并行骨干 → T_draft 低"]
        D2["顺序头 → τ 衰减慢"]
        D3["置信度调度 → 验证浪费少"]
    end

3. DSpark 算法原理

参考论文:DSpark: Confidence-Scheduled Speculative Decoding with Semi-Autoregressive Generation (DeepSeek-AI, 2026)

3.1 整体架构

flowchart TB
    subgraph Input["输入"]
        Prompt["Prompt Tokens: A B C"]
    end

    subgraph TargetStep["Target Model 前向(1步)"]
        T1["生成 token D<br/>(作为 Anchor)"]
    end

    subgraph DraftPhase["草稿生成阶段"]
        direction TB
        subgraph ParallelStage["① 并行骨干(Parallel Backbone)"]
            PB["单次前向传播<br/>输入: D + [MASK]×γ<br/>输出: U₁...Uγ (base logits)<br/>+ h₁...hγ (hidden states)"]
        end
        subgraph SequentialStage["② 顺序修正头(Sequential Head)"]
            SH["逐位置自回归采样<br/>pₖ(v) ∝ exp(Uₖ(v) + Bₖ(x₀,x<ₖ,v))<br/>Markov Head / RNN Head"]
        end
        subgraph ConfidenceHead["③ 置信度头(Confidence Head)"]
            CH["预测 cₖ ∈ (0,1)<br/>条件接受概率估计"]
        end
        PB --> SH --> CH
    end

    subgraph Scheduling["硬件感知前缀调度器"]
        HS["动态截断低置信度后缀<br/>最大化全局吞吐 Θ = τ·SPS(B)"]
    end

    subgraph VerifyPhase["Target Model 验证"]
        V1["批量验证调度后的前缀<br/>拒绝采样"]
    end

    Prompt --> TargetStep --> DraftPhase --> Scheduling --> VerifyPhase --> |"下一轮"| TargetStep

    style ParallelStage fill:#e1f5fe
    style SequentialStage fill:#f3e5f5
    style ConfidenceHead fill:#e8f5e9
    style Scheduling fill:#fff3e0

3.2 半自回归生成(Semi-Autoregressive Generation)

DSpark 将草稿生成拆分为两个级联阶段,兼具并行的高效与自回归的质量:

阶段一:并行骨干

以 Anchor token(上一轮验证通过的最后一个 token)为条件,一次性前向传播预测整个草稿块:

输入: x0(anchor)+[MASK,,MASK]\text{输入: } x_0(\text{anchor}) + [\text{MASK}, \dots, \text{MASK}] 输出: U1,,Uγ(每个位置的 base logits)\text{输出: } U_1, \dots, U_\gamma \quad \text{(每个位置的 base logits)}

代码对应deepspec/modeling/dspark/qwen3/modeling.py::Qwen3DSparkModel.forward):

# 1. 构建噪声嵌入:anchor位置放anchor token,其余放mask token
noise_embedding = create_noise_embed(...)

# 2. 并行前向传播(所有草稿位置一次计算)
output_hidden = self._forward_backbone(
    noise_embedding=noise_embedding,
    target_hidden_states=target_hidden_states,  # KV注入
    attention_mask=dspark_attn_mask,
)
# output_hidden: [B, num_anchors * block_size, D]

阶段二:顺序修正头

并行骨干的输出 UkU_k 缺乏对块内前文的条件依赖。顺序头通过偏置项 BkB_k 引入依赖:

pk(v)=softmax(Uk(v)+Bk(x0,x<k,v))p_k(v) = \text{softmax}\left(U_k(v) + B_k(x_0, x_{<k}, v)\right)

三种顺序头实现deepspec/modeling/dspark/markov_head.py):

graph LR
    subgraph Markov["VanillaMarkovHead<br/>(论文公式5)"]
        M1["B(xₖ₋₁, ·) = W₂ · W₁[xₖ₋₁]<br/>仅依赖前一个token<br/>低秩分解 r=256"]
    end
    subgraph Gated["GatedMarkovHead"]
        G1["B = σ(Wg·[hₖ;W₁[xₖ₋₁]])<br/>⊗ W₁[xₖ₋₁]<br/>引入隐藏状态门控"]
    end
    subgraph RNN["RNNHead<br/>(论文公式6)"]
        R1["sₖ = σ(Wg·zₖ)⊗sₖ₋₁<br/>+ (1-σ)(Wg·zₖ)⊗tanh(Wc·zₖ)<br/>zₖ = [sₖ₋₁; W₁[xₖ₋₁]; hₖ]<br/>GRU式循环状态"]
    end

推理时的逐位置采样markov_head.py::sample_block_tokens):

# 伪代码
for k in range(block_size):
    step_logits = base_logits[:, k, :] + B_k(prev_token, hidden_k)
    next_token = sample(step_logits)
    # 将 next_token 用于下一步的 B_{k+1} 计算

3.3 置信度调度(Confidence-Scheduled Verification)

为什么需要置信度调度?

  • 并行草稿块越长,后缀位置的接受率越低
  • 在高并发服务中,验证每个额外 token 都占用目标模型 batch 容量
  • 盲目验证所有提案 token 是浪费

置信度头设计

flowchart LR
    H["草稿位置 k 的隐藏状态 hₖ"] --> F["线性投影 + Sigmoid"]
    M["前一个 token 的 Markov 嵌入<br/>W₁[xₖ₋₁]"] --> F
    F --> C["cₖ ∈ (0,1)<br/>条件接受概率预测"]

    style C fill:#e8f5e9

监督信号(解析接受率,论文公式8):

ck=112pd(k)pt(k)1c_k^* = 1 - \frac{1}{2} \| p_d^{(k)} - p_t^{(k)} \|_1

即 draft 分布与 target 分布之间的 Total Variation Distance 的补集。

代码对应deepspec/modeling/dspark/loss.py::_compute_accept_rate_3d):

draft_probs = torch.softmax(draft_logits, dim=-1)
target_probs = torch.softmax(aligned_target_logits, dim=-1)
accept_rate_3d = 1.0 - 0.5 * (draft_probs - target_probs).abs().sum(dim=-1)

硬件感知前缀调度器(Algorithm 1)

flowchart TB
    Input["每个请求的置信度序列<br/>cᵣ,₁, cᵣ,₂, ..."] --> Sort["按生存概率 aᵣ,ⱼ 降序排序<br/>aᵣ,ⱼ = Πᵢ≤ⱼ cᵣ,ᵢ"]
    Sort --> Greedy["贪心扫描"]
    Greedy --> Check{"Θ > Θ_best?<br/>(Θ = τ*SPS(B))"}
    Check -->|"是"| Update["接受该 (r,j) 对<br/>更新 ℓ*ᵣ"]
    Check -->|"否"| Break["提前终止(保证因果性)"]
    Update --> Greedy

    style Check fill:#fff3e0
    style Break fill:#ffebee

关键设计:贪心扫描中一旦吞吐下降立即终止,确保调度决策不依赖于未来 token,满足非预期性质(non-anticipating property),保证拒绝采样正确性。


4. 核心组件与代码架构

4.1 完整模块依赖图

graph TB
    subgraph Entry["入口"]
        train_py["train.py"]
        eval_py["eval.py"]
    end

    subgraph Data["数据层 deepspec/data/"]
        TCD["target_cache_dataset.py<br/>目标缓存 mmap 读取/写入"]
        JDS["jsonl_dataset.py<br/>JSONL 内存映射读取器"]
        Parser["parser.py<br/>对话模板解析 + loss_mask 生成"]
        Prefetch["cuda_prefetcher.py<br/>GPU 异步数据预取"]
    end

    subgraph Modeling["模型层 deepspec/modeling/"]
        subgraph DSparkM["dspark/"]
            Common["common.py<br/>锚点采样、注意力掩码、噪声嵌入"]
            Markov["markov_head.py<br/>顺序修正头 (3种)"]
            DLoss["loss.py<br/>CE+L1+Confidence 联合损失"]
            DQwen["qwen3/modeling.py<br/>Qwen3DSparkModel"]
            DGemma["gemma4/modeling.py<br/>Gemma4DSparkModel"]
        end
        subgraph EagleM["eagle3/"]
            ECommon["common.py<br/>Eagle3 注意力掩码"]
            ELoss["loss.py<br/>TTT循环 + Triton融合损失"]
            EQwen["qwen3/modeling.py<br/>Qwen3Eagle3Model"]
        end
    end

    subgraph Trainer["训练层 deepspec/trainer/"]
        BaseT["base_trainer.py<br/>FSDP、梯度累积、主循环"]
        DSparkT["dspark_trainer.py<br/>DSpark 特定训练逻辑"]
        CKPT["ckpt_manager.py<br/>检查点保存/恢复"]
    end

    subgraph Eval["评估层 deepspec/eval/"]
        BaseE["base_evaluator.py<br/>投机解码循环 + 拒绝采样"]
        DSEval["dspark/evaluator.py<br/>DSpark 推理"]
        DSOps["dspark/draft_ops.py<br/>草稿生成操作"]
        EvalConf["dspark/confidence_head.py<br/>置信度校准评估"]
    end

    subgraph Utils["工具层 deepspec/utils/"]
        Config["config.py<br/>ConfigNode 配置系统"]
        Dist["distributed.py<br/>分布式初始化 + 可恢复采样器"]
        Optim["optim.py<br/>BF16Optimizer"]
        Sampling["sampling.py<br/>概率采样、残差采样"]
    end

    train_py --> BaseT
    eval_py --> BaseE
    BaseT --> DSparkM
    BaseT --> EagleM
    BaseE --> DSparkM
    BaseE --> EagleM
    DSparkT --> DQwen
    DSEval --> DQwen

    style DSparkM fill:#e1f5fe
    style EagleM fill:#f3e5f5
    style BaseE fill:#e8f5e9

4.2 配置系统

所有算法/模型的超参数统一通过 Python 配置文件中转,由 ConfigNode(支持点访问的 dict)加载。

典型 DSpark 配置结构config/dspark/dspark_qwen3_4b.py):

配置项含义DSpark 典型值DFlashEagle3
block_size每 anchor 草稿长度77N/A
num_draft_layers草稿模型 Transformer 层数551
target_layer_ids提取的目标模型隐藏层[1,9,17,25,33]同左同左
markov_rankMarkov 头低秩维度2560(关闭)N/A
ce_loss_alphaCE 损失权重0.11.0-
l1_loss_alphaL1 分布匹配权重0.90.0(关闭)-
confidence_head_alpha置信度损失权重1.00.0(关闭)N/A
loss_decay_gamma位置衰减参数4.0--

5. 三种算法对比

5.1 架构对比

graph LR
    subgraph DFlashArch["DFlash (纯并行)"]
        D1["输入: anchor + [MASK]×γ"]
        D2["单次前向 → 所有位置 logits"]
        D3["无块内依赖 → 接受率衰减快"]
        D1 --> D2 --> D3
    end

    subgraph DSparkArch["DSpark (半自回归)"]
        S1["输入: anchor + [MASK]×γ"]
        S2["并行骨干 → base logits"]
        S3["顺序头 → 加入块内依赖偏置"]
        S4["置信度头 → 预测接受概率"]
        S1 --> S2 --> S3 --> S4
    end

    subgraph Eagle3Arch["Eagle3 (自回归+TTT)"]
        E1["输入: 前一隐藏状态"]
        E2["逐 token 自回归生成"]
        E3["TTT循环: 用target logits蒸馏"]
        E4["严格自回归 → τ高但T_draft大"]
        E1 --> E2 --> E3 --> E4
    end

    Eagle3Arch ~~~ DFlashArch ~~~ DSparkArch

5.2 性能对比(来自论文)

指标Eagle3 (AR)DFlash (Parallel)DSpark (Ours)
接受长度 (Qwen3-4B)基线+16.3% over Eagle3+30.9% over Eagle3
草稿延迟∝ γ(线性)≈ 常数≈ 常数
块内依赖建模完整轻量(顺序头)
置信度调度

6. 端到端数据流

6.1 完整流水线

flowchart TD
    subgraph Phase1["阶段1: 数据准备"]
        D1["下载 Prompt 数据集<br/>mlabonne/open-perfectblend<br/>~130万条多轮对话"]
        D2["切分: 95% 训练 / 5% 评估"]
        D3["启动 SGLang 推理服务<br/>(8 GPU workers)"]
        D4["用目标模型重生成答案<br/>generate_train_data.py"]
        D5["构建目标缓存 Target Cache<br/>prepare_target_cache.py<br/>⚠️ 约 38TB (Qwen3-4B)"]
        D1 --> D2 --> D3 --> D4 --> D5
    end

    subgraph Phase2["阶段2: 训练"]
        T1["加载 Target Cache<br/>(mmap 按需读取)"]
        T2["每样本随机采样多个 Anchor 位置"]
        T3["DSpark 前向传播<br/>并行骨干 + 顺序头"]
        T4["计算联合损失<br/>CE + L1 + Confidence"]
        T5["FSDP 分布式训练"]
        T6["保存检查点<br/>~/checkpoints/deepspec/"]
        T1 --> T2 --> T3 --> T4 --> T5 --> T6
    end

    subgraph Phase3["阶段3: 评估"]
        E1["加载训练好的草稿模型"]
        E2["逐数据集运行投机解码"]
        E3["拒绝采样验证"]
        E4["输出指标:<br/>acceptance_length<br/>verify_rate<br/>per-position accept_rates"]
        E1 --> E2 --> E3 --> E4
    end

    Phase1 --> Phase2 --> Phase3

    style D5 fill:#fff3e0
    style T5 fill:#e1f5fe
    style E2 fill:#e8f5e9

6.2 Target Cache 二进制格式

这是整个数据管线的核心设计(代码:deepspec/data/target_cache_dataset.py)。

graph LR
    subgraph Manifest["manifest.json"]
        M1["版本、样本数、分片数"]
        M2["target_layer_ids"]
        M3["hidden_size, dtype"]
    end

    subgraph Index["samples.idx (固定大小)"]
        I1["每条记录 64 字节"]
        I2["sample_id (8B)"]
        I3["shard_id (4B) + seq_len (4B)"]
        I4["各字段 offset (8B×5)"]
    end

    subgraph Shards["shard-00000.bin, ..."]
        S1["mmap 直接读取"]
        S2["变长样本连续存储"]
    end

    Manifest --> Index --> Shards

存储内容(每样本):

  • input_ids:原始 token IDs
  • loss_mask:哪些位置的 target 输出需要计算损失
  • target_hidden_states:指定目标层的隐藏状态(用于 KV 注入)
  • target_last_hidden_states:目标模型最后一层隐藏状态(用于对齐 logits)

7. 训练流程深度剖析

7.1 训练主循环

flowchart TD
    Start["torch.multiprocessing.spawn<br/>每 GPU 一个 worker"] --> Init["初始化分布式环境"]
    Init --> Build["build_models()"]
    Build --> LoadCache["加载 Target Cache<br/>(CacheDataset, mmap)"]
    LoadCache --> Loop["训练循环"]

    subgraph LoopDetail["每个 batch"]
        L1["CUDAPrefetcher 异步 H2D 传输"]
        L2["前向传播<br/>sample_anchor_positions()"]
        L3["构建 DSpark 注意力掩码"]
        L4["_forward_backbone()<br/>并行骨干前向"]
        L5["apply_block_logits()<br/>顺序头修正"]
        L6["compute_dspark_loss()<br/>联合损失"]
        L7["反向传播 + 梯度累积"]
        L8{"该同步?"}
        L9["clip_grad_norm_ + optimizer.step()"]
        L10["记录指标 + 保存检查点"]
        L1 --> L2 --> L3 --> L4 --> L5 --> L6 --> L7 --> L8
        L8 -->|"是"| L9 --> L10
        L8 -->|"否"| L1
    end

    Loop --> Final["训练完成<br/>clean_up()"]

7.2 联合损失函数详解

代码对应deepspec/modeling/dspark/loss.py::compute_dspark_loss):

L = α_ce × L_ce + α_l1 × L_l1 + α_conf × L_conf

其中:

损失项公式作用权重 (DSpark)
CE LossCrossEntropy(draft_logits, target_ids)稀疏监督信号,学正确 token0.1
L1 Loss12pdpt1\frac{1}{2}\sum\|p_d - p_t\|_1稠密分布匹配,学整体分布0.9
Confidence LossBCE(confidence_pred, accept_rate*)校准置信度头1.0

位置衰减权重

# loss.py::_build_loss_weight_mask
weights_k = exp(-(k-1) / loss_decay_gamma)  # γ = 4.0

块内越靠后的位置权重越低,因为后续位置的监督信号天然更不可靠。

7.3 分布式训练要点

  • FSDPno_shard / full_shard / hybrid_shard
  • BF16Optimizer:维护 fp32 master 参数,每步拷贝 bf16 梯度
  • StatelessResumableDistributedSampler:确定性 shuffle,支持跨 epoch 连续采样
  • 梯度累积global_batch_size / (world_size * local_batch_size)

8. 推理与评估流程

8.1 投机解码循环(核心)

这是 deepspec/eval/base_evaluator.py::generate_decoding_sample() 的完整逻辑:

stateDiagram-v2
    [*] --> Prefill: 输入 prompt
    Prefill --> InitContext: Target Model 前向<br/>提取隐藏状态
    InitContext --> Propose: 草稿模型生成候选 token
    Propose --> Verify: Target Model 验证
    Verify --> AcceptCheck: 拒绝采样
    AcceptCheck --> |"全部接受"| Propose
    AcceptCheck --> |"部分接受"| SampleResidual: 从残差分布采样
    SampleResidual --> Update: 更新 KV Cache<br/>更新隐藏状态
    Update --> Propose: 下一轮草稿
    AcceptCheck --> |"全部拒绝"| SampleResidual
    SampleResidual --> CheckStop: 检查停止 token
    CheckStop --> |"未停止"| Update
    CheckStop --> |"停止"| [*]

8.2 拒绝采样验证详解

代码对应base_evaluator.py::verify_draft_tokens):

# 对每个草稿位置 k:
accept_prob_k = min(1, p_target(token_k) / p_draft(token_k))
if random() < accept_prob_k:
    accept token_k
else:
    reject token_k and all subsequent tokens
    # 从残差分布采样下一个 token
    next_token ~ max(0, p_target - p_draft)

数学保证:拒绝采样确保输出分布与 target model 分布完全一致(无质量损失)。

8.3 DSpark 推理特化

flowchart TD
    Init["_init_context()"] --> |"提取 target hidden states"| Prop["propose()"]
    Prop --> Draft["forward_dspark_draft_block()"]
    Draft --> Markov["apply Markov Head<br/>(如果启用)"]
    Markov --> Conf["预测置信度 cₖ"]
    Conf --> Schedule["按置信度截断草稿"]
    Schedule --> Verify["verify_draft_tokens()"]
    Verify --> Update["_update()<br/>提取新隐藏状态"]
    Update --> Prop

9. 关键设计决策

9.1 KV 注入(Context Feature Injection)

DSpark(继承自 DFlash)的核心设计:将目标模型多层隐藏状态注入草稿模型的注意力机制。

graph LR
    subgraph Target["目标模型"]
        T1["Layer 1 → H₁"]
        T2["Layer 9 → H₉"]
        T3["Layer 17 → H₁₇"]
        T4["Layer 25 → H₂₅"]
        T5["Layer 33 → H₃₃"]
    end

    subgraph Projection["特征投影"]
        P1["Concat(H₁,H₉,H₁₇,H₂₅,H₃₃)"]
        P2["RMSNorm"]
        P3["fc: ℝ^{5D} → ℝ^D"]
    end

    subgraph DraftAttn["DSpark Attention"]
        A1["Q: 仅来自草稿 token"]
        A2["K/V: [Target_Context; Draft_Block]"]
        A3["注意力掩码:<br/>草稿位置单向可见<br/>context 双向可见"]
    end

    T1 & T2 & T3 & T4 & T5 --> P1 --> P2 --> P3 --> A2
    A1 --> A3

代码对应qwen3/modeling.py::Qwen3DSparkAttention.forward):

# Q 来自草稿隐藏状态
q = self.q_proj(hidden_states)

# K/V 拼接 target context + 草稿
k = torch.cat([self.k_proj(target_hidden_states),
               self.k_proj(hidden_states)], dim=1)
v = torch.cat([self.v_proj(target_hidden_states),
               self.v_proj(hidden_states)], dim=1)

# 使用 DSpark 专用注意力掩码(flex_attention)
attn_output = flex_attention(q, k, v, attn_mask=dspark_attn_mask)

9.2 注意力掩码设计

DSpark 的注意力掩码确保每个草稿位置:

  1. 可以看到 context 中 anchor 之前的全部位置
  2. 可以看到同 block 内之前的草稿位置(引入块内依赖)
  3. 不能看到 context 中 anchor 之后的位置(保持因果性)
graph TD
    subgraph Mask["DSpark 注意力掩码(单一样本)"]
        direction TB
        A["Context: [A B C D] (anchor=D)"]
        B["Draft Block: [E F G H] (block_size=4)"]
        A --> B

        M["注意力权限矩阵:<br/>E: 看 A B C D<br/>F: 看 A B C D + E<br/>G: 看 A B C D + E + F<br/>H: 看 A B C D + E + F + G"]
    end

9.3 可恢复的分布式采样器

StatelessResumableDistributedSampler 是一个精妙设计:

  • 只依赖 start_global_offset_samples 和 deterministic seed
  • 可以从任意位置恢复训练(支持容错)
  • 每个 epoch 使用不同 seed 的 randperm,保证覆盖全量数据

10. 总结

10.1 DSpark 的技术创新

mindmap
  root((DSpark))
    半自回归生成
      并行骨干<br/>低延迟
      顺序修正头<br/>高质量
      Markov/RNN Head<br/>轻量依赖建模
    置信度调度
      线性置信度头
      解析接受率监督
      硬件感知前缀调度
      贪心截断最大化吞吐
    工程实现
      Target Cache mmap
      FSDP 分布式训练
      flex_attention 高效稀疏注意力
      Triton 融合损失

10.2 代码阅读路径推荐

顺序文件目的
1README.md项目概览
2config/dspark/dspark_qwen3_4b.py理解全部超参数
3deepspec/data/target_cache_dataset.py理解数据格式
4deepspec/modeling/dspark/common.py锚点采样、注意力掩码
5deepspec/modeling/dspark/qwen3/modeling.pyDSpark 模型完整前向
6deepspec/modeling/dspark/markov_head.py顺序修正头实现
7deepspec/modeling/dspark/loss.py联合训练目标
8deepspec/trainer/base_trainer.py训练主循环
9deepspec/eval/base_evaluator.py投机解码 + 拒绝采样
10DSpark_paper.pdf论文原始描述

10.3 项目意义

DeepSpec 不仅开源了 DSpark 的训练代码,还一并开源了 DFlashEagle3 的实现,为投机解码研究提供了一个统一的、算法驱动的训练框架。结合 DeepSeek-V4 的生产部署结果(60%-85% 的单用户加速),DSpark 已被验证为当前最先进的投机解码算法之一。