基于 DSpark 的投机解码训练框架原理与实现(论文+代码对照)
结合 DSpark 论文与代码实现,全面剖析 DeepSpec 的工作原理与核心组件。
项目地址:https://github.com/deepseek-ai/DeepSpec DSpark 论文:https://github.com/deepseek-ai/DeepSpec/blob/main/DSpark_paper.pdf
DSpark 是 DeepSeek 提出的一套无损加速大模型推理的“看人下菜碟”机制。 传统加速手段(推测解码)通常是让小模型一次性盲目盲猜一大串后续 Token,再让大模型统一验证。但这存在两个痛点:小模型猜得越往后越不准(多模态冲突导致“后缀衰减”);高并发时,大模型花大力气去验证那些猜得不准的 Token,会严重压垮系统吞吐。
DSpark 的核心突破就在于两点:
- 猜得更准(半自回归): 它在原有的单次并行生成网络后,拼了一个极轻量的小尾巴(顺序头),在几乎不增加延迟的情况下,让后面的 Token 能根据前面猜出的 Token 进行自适应修正,大幅提升长序列的猜测准确度。
- 动态裁剪(置信度调度): 它能实时感知系统的硬件负载与并发压力。如果并发高、大模型很忙,或者发现后面小模型猜的置信度太低,它就会果断把不靠谱的后缀砍掉,只送靠谱的前缀给大模型验证。
通过这种“高质量猜测”与“负载感知动态裁剪”的结合,DSpark 在保障大模型输出质量完全无损的前提下,成功在 DeepSeek 真实的高并发生产环境中,让用户端生成速度暴增了 60% ~ 85%。
graph TD
A[Target Model 验证结束] -->|输出 Anchor Token D| B(DSpark 级联生成)
subgraph DSpark 草稿生成阶段
B --> C[并行骨干网络 Parallel Backbone]
C -->|输出 Base Logits & 隐状态| D[轻量化顺序头 Sequential Block]
D -->|注入局部转移偏置| E[生成候选序列 E, F, G, H]
E --> F[置信度头 Confidence Head]
F -->|评估序列条件生存概率 C1-C4| G[连续前缀校准 STS]
end
G --> H[硬件感知前缀调度器 Hardware-Aware Prefix Scheduler]
H -->|结合实时系统吞吐曲线 SPS| I{动态截断决策}
I -->|保留高回报前缀 E, F, G| J[Target Model 并行验证]
I -->|裁剪低置信度后缀 H| K[释放验证算力/Batch容量]
目录
- 项目概述
- 背景:投机解码(Speculative Decoding)
- DSpark 算法原理(论文核心)
- 核心组件与代码架构
- 三种算法对比:DSpark / DFlash / Eagle3
- 端到端数据流
- 训练流程深度剖析
- 推理与评估流程
- 关键设计决策
- 总结
1. 项目概述
DeepSpec 是一个用于训练与评估投机解码(Speculative Decoding)草稿模型的全栈代码库,由 DeepSeek-AI 团队开源。
核心问题
大型语言模型(LLM)推理时逐 token 生成,延迟与输出长度成正比。投机解码通过引入一个轻量草稿模型(Draft Model)并行提出多个候选 token,再由目标模型(Target Model,即完整大模型)单次前向传播批量验证,从而加速推理。
项目组成
DeepSpec/
├── 三种草稿模型算法
│ ├── DSpark — 半自回归 + 置信度调度(最新、最强)
│ ├── DFlash — 纯并行生成(DSpark 的简化版)
│ └── Eagle3 — 自回归 + TTT(Test-Time Training)
├── 完整训练管线(数据准备 → 训练 → 评估)
├── 支持的目标模型:Qwen3 (4B/8B/14B)、Gemma4 (12B)
└── 评估基准:gsm8k, math500, aime25, humaneval, mbpp, ...
2. 背景:投机解码
2.1 基本原理
sequenceDiagram
participant T as Target Model (大模型)
participant D as Draft Model (小模型)
participant O as 输出
Note over T,D: 每个解码周期"
D->>T: 提出 γ 个候选 token (并行)
T->>T: 单次前向传播验证全部候选
T->>O: 接受最长一致前缀 + 1 个 bonus token
Note over O: 分布完全保留(拒绝采样保证)
2.2 延迟公式
论文中的核心公式:
- :草稿生成延迟
- :提案 token 数
- :接受长度(expected accepted tokens per round)
目标:最小化 ,即同时降低 和增大 。
2.3 两类草稿模型的设计权衡
graph LR
subgraph 自回归["Autoregressive Drafters (Eagle3)"]
A1["T_draft ∝ γ<br/>延迟随块大小线性增长"]
A2["τ 高<br/>逐 token 条件依赖建模好"]
A3["γ 与模型深度受限<br/>被迫使用小 γ 和浅层架构"]
end
subgraph 并行["Parallel Drafters (DFlash)"]
P1["T_draft ≈ 常数<br/>单次前向传播"]
P2["τ 衰减快<br/>无块内 token 依赖"]
P3["可用大 γ + 深层架构"]
end
subgraph DSpark["DSpark (本文方法)"]
D1["并行骨干 → T_draft 低"]
D2["顺序头 → τ 衰减慢"]
D3["置信度调度 → 验证浪费少"]
end
3. DSpark 算法原理
参考论文:DSpark: Confidence-Scheduled Speculative Decoding with Semi-Autoregressive Generation (DeepSeek-AI, 2026)
3.1 整体架构
flowchart TB
subgraph Input["输入"]
Prompt["Prompt Tokens: A B C"]
end
subgraph TargetStep["Target Model 前向(1步)"]
T1["生成 token D<br/>(作为 Anchor)"]
end
subgraph DraftPhase["草稿生成阶段"]
direction TB
subgraph ParallelStage["① 并行骨干(Parallel Backbone)"]
PB["单次前向传播<br/>输入: D + [MASK]×γ<br/>输出: U₁...Uγ (base logits)<br/>+ h₁...hγ (hidden states)"]
end
subgraph SequentialStage["② 顺序修正头(Sequential Head)"]
SH["逐位置自回归采样<br/>pₖ(v) ∝ exp(Uₖ(v) + Bₖ(x₀,x<ₖ,v))<br/>Markov Head / RNN Head"]
end
subgraph ConfidenceHead["③ 置信度头(Confidence Head)"]
CH["预测 cₖ ∈ (0,1)<br/>条件接受概率估计"]
end
PB --> SH --> CH
end
subgraph Scheduling["硬件感知前缀调度器"]
HS["动态截断低置信度后缀<br/>最大化全局吞吐 Θ = τ·SPS(B)"]
end
subgraph VerifyPhase["Target Model 验证"]
V1["批量验证调度后的前缀<br/>拒绝采样"]
end
Prompt --> TargetStep --> DraftPhase --> Scheduling --> VerifyPhase --> |"下一轮"| TargetStep
style ParallelStage fill:#e1f5fe
style SequentialStage fill:#f3e5f5
style ConfidenceHead fill:#e8f5e9
style Scheduling fill:#fff3e0
3.2 半自回归生成(Semi-Autoregressive Generation)
DSpark 将草稿生成拆分为两个级联阶段,兼具并行的高效与自回归的质量:
阶段一:并行骨干
以 Anchor token(上一轮验证通过的最后一个 token)为条件,一次性前向传播预测整个草稿块:
代码对应(deepspec/modeling/dspark/qwen3/modeling.py::Qwen3DSparkModel.forward):
# 1. 构建噪声嵌入:anchor位置放anchor token,其余放mask token
noise_embedding = create_noise_embed(...)
# 2. 并行前向传播(所有草稿位置一次计算)
output_hidden = self._forward_backbone(
noise_embedding=noise_embedding,
target_hidden_states=target_hidden_states, # KV注入
attention_mask=dspark_attn_mask,
)
# output_hidden: [B, num_anchors * block_size, D]
阶段二:顺序修正头
并行骨干的输出 缺乏对块内前文的条件依赖。顺序头通过偏置项 引入依赖:
三种顺序头实现(deepspec/modeling/dspark/markov_head.py):
graph LR
subgraph Markov["VanillaMarkovHead<br/>(论文公式5)"]
M1["B(xₖ₋₁, ·) = W₂ · W₁[xₖ₋₁]<br/>仅依赖前一个token<br/>低秩分解 r=256"]
end
subgraph Gated["GatedMarkovHead"]
G1["B = σ(Wg·[hₖ;W₁[xₖ₋₁]])<br/>⊗ W₁[xₖ₋₁]<br/>引入隐藏状态门控"]
end
subgraph RNN["RNNHead<br/>(论文公式6)"]
R1["sₖ = σ(Wg·zₖ)⊗sₖ₋₁<br/>+ (1-σ)(Wg·zₖ)⊗tanh(Wc·zₖ)<br/>zₖ = [sₖ₋₁; W₁[xₖ₋₁]; hₖ]<br/>GRU式循环状态"]
end
推理时的逐位置采样(markov_head.py::sample_block_tokens):
# 伪代码
for k in range(block_size):
step_logits = base_logits[:, k, :] + B_k(prev_token, hidden_k)
next_token = sample(step_logits)
# 将 next_token 用于下一步的 B_{k+1} 计算
3.3 置信度调度(Confidence-Scheduled Verification)
为什么需要置信度调度?
- 并行草稿块越长,后缀位置的接受率越低
- 在高并发服务中,验证每个额外 token 都占用目标模型 batch 容量
- 盲目验证所有提案 token 是浪费
置信度头设计
flowchart LR
H["草稿位置 k 的隐藏状态 hₖ"] --> F["线性投影 + Sigmoid"]
M["前一个 token 的 Markov 嵌入<br/>W₁[xₖ₋₁]"] --> F
F --> C["cₖ ∈ (0,1)<br/>条件接受概率预测"]
style C fill:#e8f5e9
监督信号(解析接受率,论文公式8):
即 draft 分布与 target 分布之间的 Total Variation Distance 的补集。
代码对应(deepspec/modeling/dspark/loss.py::_compute_accept_rate_3d):
draft_probs = torch.softmax(draft_logits, dim=-1)
target_probs = torch.softmax(aligned_target_logits, dim=-1)
accept_rate_3d = 1.0 - 0.5 * (draft_probs - target_probs).abs().sum(dim=-1)
硬件感知前缀调度器(Algorithm 1)
flowchart TB
Input["每个请求的置信度序列<br/>cᵣ,₁, cᵣ,₂, ..."] --> Sort["按生存概率 aᵣ,ⱼ 降序排序<br/>aᵣ,ⱼ = Πᵢ≤ⱼ cᵣ,ᵢ"]
Sort --> Greedy["贪心扫描"]
Greedy --> Check{"Θ > Θ_best?<br/>(Θ = τ*SPS(B))"}
Check -->|"是"| Update["接受该 (r,j) 对<br/>更新 ℓ*ᵣ"]
Check -->|"否"| Break["提前终止(保证因果性)"]
Update --> Greedy
style Check fill:#fff3e0
style Break fill:#ffebee
关键设计:贪心扫描中一旦吞吐下降立即终止,确保调度决策不依赖于未来 token,满足非预期性质(non-anticipating property),保证拒绝采样正确性。
4. 核心组件与代码架构
4.1 完整模块依赖图
graph TB
subgraph Entry["入口"]
train_py["train.py"]
eval_py["eval.py"]
end
subgraph Data["数据层 deepspec/data/"]
TCD["target_cache_dataset.py<br/>目标缓存 mmap 读取/写入"]
JDS["jsonl_dataset.py<br/>JSONL 内存映射读取器"]
Parser["parser.py<br/>对话模板解析 + loss_mask 生成"]
Prefetch["cuda_prefetcher.py<br/>GPU 异步数据预取"]
end
subgraph Modeling["模型层 deepspec/modeling/"]
subgraph DSparkM["dspark/"]
Common["common.py<br/>锚点采样、注意力掩码、噪声嵌入"]
Markov["markov_head.py<br/>顺序修正头 (3种)"]
DLoss["loss.py<br/>CE+L1+Confidence 联合损失"]
DQwen["qwen3/modeling.py<br/>Qwen3DSparkModel"]
DGemma["gemma4/modeling.py<br/>Gemma4DSparkModel"]
end
subgraph EagleM["eagle3/"]
ECommon["common.py<br/>Eagle3 注意力掩码"]
ELoss["loss.py<br/>TTT循环 + Triton融合损失"]
EQwen["qwen3/modeling.py<br/>Qwen3Eagle3Model"]
end
end
subgraph Trainer["训练层 deepspec/trainer/"]
BaseT["base_trainer.py<br/>FSDP、梯度累积、主循环"]
DSparkT["dspark_trainer.py<br/>DSpark 特定训练逻辑"]
CKPT["ckpt_manager.py<br/>检查点保存/恢复"]
end
subgraph Eval["评估层 deepspec/eval/"]
BaseE["base_evaluator.py<br/>投机解码循环 + 拒绝采样"]
DSEval["dspark/evaluator.py<br/>DSpark 推理"]
DSOps["dspark/draft_ops.py<br/>草稿生成操作"]
EvalConf["dspark/confidence_head.py<br/>置信度校准评估"]
end
subgraph Utils["工具层 deepspec/utils/"]
Config["config.py<br/>ConfigNode 配置系统"]
Dist["distributed.py<br/>分布式初始化 + 可恢复采样器"]
Optim["optim.py<br/>BF16Optimizer"]
Sampling["sampling.py<br/>概率采样、残差采样"]
end
train_py --> BaseT
eval_py --> BaseE
BaseT --> DSparkM
BaseT --> EagleM
BaseE --> DSparkM
BaseE --> EagleM
DSparkT --> DQwen
DSEval --> DQwen
style DSparkM fill:#e1f5fe
style EagleM fill:#f3e5f5
style BaseE fill:#e8f5e9
4.2 配置系统
所有算法/模型的超参数统一通过 Python 配置文件中转,由 ConfigNode(支持点访问的 dict)加载。
典型 DSpark 配置结构(config/dspark/dspark_qwen3_4b.py):
| 配置项 | 含义 | DSpark 典型值 | DFlash | Eagle3 |
|---|---|---|---|---|
block_size | 每 anchor 草稿长度 | 7 | 7 | N/A |
num_draft_layers | 草稿模型 Transformer 层数 | 5 | 5 | 1 |
target_layer_ids | 提取的目标模型隐藏层 | [1,9,17,25,33] | 同左 | 同左 |
markov_rank | Markov 头低秩维度 | 256 | 0(关闭) | N/A |
ce_loss_alpha | CE 损失权重 | 0.1 | 1.0 | - |
l1_loss_alpha | L1 分布匹配权重 | 0.9 | 0.0(关闭) | - |
confidence_head_alpha | 置信度损失权重 | 1.0 | 0.0(关闭) | N/A |
loss_decay_gamma | 位置衰减参数 | 4.0 | - | - |
5. 三种算法对比
5.1 架构对比
graph LR
subgraph DFlashArch["DFlash (纯并行)"]
D1["输入: anchor + [MASK]×γ"]
D2["单次前向 → 所有位置 logits"]
D3["无块内依赖 → 接受率衰减快"]
D1 --> D2 --> D3
end
subgraph DSparkArch["DSpark (半自回归)"]
S1["输入: anchor + [MASK]×γ"]
S2["并行骨干 → base logits"]
S3["顺序头 → 加入块内依赖偏置"]
S4["置信度头 → 预测接受概率"]
S1 --> S2 --> S3 --> S4
end
subgraph Eagle3Arch["Eagle3 (自回归+TTT)"]
E1["输入: 前一隐藏状态"]
E2["逐 token 自回归生成"]
E3["TTT循环: 用target logits蒸馏"]
E4["严格自回归 → τ高但T_draft大"]
E1 --> E2 --> E3 --> E4
end
Eagle3Arch ~~~ DFlashArch ~~~ DSparkArch
5.2 性能对比(来自论文)
| 指标 | Eagle3 (AR) | DFlash (Parallel) | DSpark (Ours) |
|---|---|---|---|
| 接受长度 (Qwen3-4B) | 基线 | +16.3% over Eagle3 | +30.9% over Eagle3 |
| 草稿延迟 | ∝ γ(线性) | ≈ 常数 | ≈ 常数 |
| 块内依赖建模 | 完整 | 无 | 轻量(顺序头) |
| 置信度调度 | ❌ | ❌ | ✅ |
6. 端到端数据流
6.1 完整流水线
flowchart TD
subgraph Phase1["阶段1: 数据准备"]
D1["下载 Prompt 数据集<br/>mlabonne/open-perfectblend<br/>~130万条多轮对话"]
D2["切分: 95% 训练 / 5% 评估"]
D3["启动 SGLang 推理服务<br/>(8 GPU workers)"]
D4["用目标模型重生成答案<br/>generate_train_data.py"]
D5["构建目标缓存 Target Cache<br/>prepare_target_cache.py<br/>⚠️ 约 38TB (Qwen3-4B)"]
D1 --> D2 --> D3 --> D4 --> D5
end
subgraph Phase2["阶段2: 训练"]
T1["加载 Target Cache<br/>(mmap 按需读取)"]
T2["每样本随机采样多个 Anchor 位置"]
T3["DSpark 前向传播<br/>并行骨干 + 顺序头"]
T4["计算联合损失<br/>CE + L1 + Confidence"]
T5["FSDP 分布式训练"]
T6["保存检查点<br/>~/checkpoints/deepspec/"]
T1 --> T2 --> T3 --> T4 --> T5 --> T6
end
subgraph Phase3["阶段3: 评估"]
E1["加载训练好的草稿模型"]
E2["逐数据集运行投机解码"]
E3["拒绝采样验证"]
E4["输出指标:<br/>acceptance_length<br/>verify_rate<br/>per-position accept_rates"]
E1 --> E2 --> E3 --> E4
end
Phase1 --> Phase2 --> Phase3
style D5 fill:#fff3e0
style T5 fill:#e1f5fe
style E2 fill:#e8f5e9
6.2 Target Cache 二进制格式
这是整个数据管线的核心设计(代码:deepspec/data/target_cache_dataset.py)。
graph LR
subgraph Manifest["manifest.json"]
M1["版本、样本数、分片数"]
M2["target_layer_ids"]
M3["hidden_size, dtype"]
end
subgraph Index["samples.idx (固定大小)"]
I1["每条记录 64 字节"]
I2["sample_id (8B)"]
I3["shard_id (4B) + seq_len (4B)"]
I4["各字段 offset (8B×5)"]
end
subgraph Shards["shard-00000.bin, ..."]
S1["mmap 直接读取"]
S2["变长样本连续存储"]
end
Manifest --> Index --> Shards
存储内容(每样本):
input_ids:原始 token IDsloss_mask:哪些位置的 target 输出需要计算损失target_hidden_states:指定目标层的隐藏状态(用于 KV 注入)target_last_hidden_states:目标模型最后一层隐藏状态(用于对齐 logits)
7. 训练流程深度剖析
7.1 训练主循环
flowchart TD
Start["torch.multiprocessing.spawn<br/>每 GPU 一个 worker"] --> Init["初始化分布式环境"]
Init --> Build["build_models()"]
Build --> LoadCache["加载 Target Cache<br/>(CacheDataset, mmap)"]
LoadCache --> Loop["训练循环"]
subgraph LoopDetail["每个 batch"]
L1["CUDAPrefetcher 异步 H2D 传输"]
L2["前向传播<br/>sample_anchor_positions()"]
L3["构建 DSpark 注意力掩码"]
L4["_forward_backbone()<br/>并行骨干前向"]
L5["apply_block_logits()<br/>顺序头修正"]
L6["compute_dspark_loss()<br/>联合损失"]
L7["反向传播 + 梯度累积"]
L8{"该同步?"}
L9["clip_grad_norm_ + optimizer.step()"]
L10["记录指标 + 保存检查点"]
L1 --> L2 --> L3 --> L4 --> L5 --> L6 --> L7 --> L8
L8 -->|"是"| L9 --> L10
L8 -->|"否"| L1
end
Loop --> Final["训练完成<br/>clean_up()"]
7.2 联合损失函数详解
代码对应(deepspec/modeling/dspark/loss.py::compute_dspark_loss):
L = α_ce × L_ce + α_l1 × L_l1 + α_conf × L_conf
其中:
| 损失项 | 公式 | 作用 | 权重 (DSpark) |
|---|---|---|---|
| CE Loss | CrossEntropy(draft_logits, target_ids) | 稀疏监督信号,学正确 token | 0.1 |
| L1 Loss | 稠密分布匹配,学整体分布 | 0.9 | |
| Confidence Loss | BCE(confidence_pred, accept_rate*) | 校准置信度头 | 1.0 |
位置衰减权重:
# loss.py::_build_loss_weight_mask
weights_k = exp(-(k-1) / loss_decay_gamma) # γ = 4.0
块内越靠后的位置权重越低,因为后续位置的监督信号天然更不可靠。
7.3 分布式训练要点
- FSDP(
no_shard/full_shard/hybrid_shard) - BF16Optimizer:维护 fp32 master 参数,每步拷贝 bf16 梯度
- StatelessResumableDistributedSampler:确定性 shuffle,支持跨 epoch 连续采样
- 梯度累积:
global_batch_size / (world_size * local_batch_size)
8. 推理与评估流程
8.1 投机解码循环(核心)
这是 deepspec/eval/base_evaluator.py::generate_decoding_sample() 的完整逻辑:
stateDiagram-v2
[*] --> Prefill: 输入 prompt
Prefill --> InitContext: Target Model 前向<br/>提取隐藏状态
InitContext --> Propose: 草稿模型生成候选 token
Propose --> Verify: Target Model 验证
Verify --> AcceptCheck: 拒绝采样
AcceptCheck --> |"全部接受"| Propose
AcceptCheck --> |"部分接受"| SampleResidual: 从残差分布采样
SampleResidual --> Update: 更新 KV Cache<br/>更新隐藏状态
Update --> Propose: 下一轮草稿
AcceptCheck --> |"全部拒绝"| SampleResidual
SampleResidual --> CheckStop: 检查停止 token
CheckStop --> |"未停止"| Update
CheckStop --> |"停止"| [*]
8.2 拒绝采样验证详解
代码对应(base_evaluator.py::verify_draft_tokens):
# 对每个草稿位置 k:
accept_prob_k = min(1, p_target(token_k) / p_draft(token_k))
if random() < accept_prob_k:
accept token_k
else:
reject token_k and all subsequent tokens
# 从残差分布采样下一个 token
next_token ~ max(0, p_target - p_draft)
数学保证:拒绝采样确保输出分布与 target model 分布完全一致(无质量损失)。
8.3 DSpark 推理特化
flowchart TD
Init["_init_context()"] --> |"提取 target hidden states"| Prop["propose()"]
Prop --> Draft["forward_dspark_draft_block()"]
Draft --> Markov["apply Markov Head<br/>(如果启用)"]
Markov --> Conf["预测置信度 cₖ"]
Conf --> Schedule["按置信度截断草稿"]
Schedule --> Verify["verify_draft_tokens()"]
Verify --> Update["_update()<br/>提取新隐藏状态"]
Update --> Prop
9. 关键设计决策
9.1 KV 注入(Context Feature Injection)
DSpark(继承自 DFlash)的核心设计:将目标模型多层隐藏状态注入草稿模型的注意力机制。
graph LR
subgraph Target["目标模型"]
T1["Layer 1 → H₁"]
T2["Layer 9 → H₉"]
T3["Layer 17 → H₁₇"]
T4["Layer 25 → H₂₅"]
T5["Layer 33 → H₃₃"]
end
subgraph Projection["特征投影"]
P1["Concat(H₁,H₉,H₁₇,H₂₅,H₃₃)"]
P2["RMSNorm"]
P3["fc: ℝ^{5D} → ℝ^D"]
end
subgraph DraftAttn["DSpark Attention"]
A1["Q: 仅来自草稿 token"]
A2["K/V: [Target_Context; Draft_Block]"]
A3["注意力掩码:<br/>草稿位置单向可见<br/>context 双向可见"]
end
T1 & T2 & T3 & T4 & T5 --> P1 --> P2 --> P3 --> A2
A1 --> A3
代码对应(qwen3/modeling.py::Qwen3DSparkAttention.forward):
# Q 来自草稿隐藏状态
q = self.q_proj(hidden_states)
# K/V 拼接 target context + 草稿
k = torch.cat([self.k_proj(target_hidden_states),
self.k_proj(hidden_states)], dim=1)
v = torch.cat([self.v_proj(target_hidden_states),
self.v_proj(hidden_states)], dim=1)
# 使用 DSpark 专用注意力掩码(flex_attention)
attn_output = flex_attention(q, k, v, attn_mask=dspark_attn_mask)
9.2 注意力掩码设计
DSpark 的注意力掩码确保每个草稿位置:
- 可以看到 context 中 anchor 之前的全部位置
- 可以看到同 block 内之前的草稿位置(引入块内依赖)
- 不能看到 context 中 anchor 之后的位置(保持因果性)
graph TD
subgraph Mask["DSpark 注意力掩码(单一样本)"]
direction TB
A["Context: [A B C D] (anchor=D)"]
B["Draft Block: [E F G H] (block_size=4)"]
A --> B
M["注意力权限矩阵:<br/>E: 看 A B C D<br/>F: 看 A B C D + E<br/>G: 看 A B C D + E + F<br/>H: 看 A B C D + E + F + G"]
end
9.3 可恢复的分布式采样器
StatelessResumableDistributedSampler 是一个精妙设计:
- 只依赖
start_global_offset_samples和 deterministic seed - 可以从任意位置恢复训练(支持容错)
- 每个 epoch 使用不同 seed 的
randperm,保证覆盖全量数据
10. 总结
10.1 DSpark 的技术创新
mindmap
root((DSpark))
半自回归生成
并行骨干<br/>低延迟
顺序修正头<br/>高质量
Markov/RNN Head<br/>轻量依赖建模
置信度调度
线性置信度头
解析接受率监督
硬件感知前缀调度
贪心截断最大化吞吐
工程实现
Target Cache mmap
FSDP 分布式训练
flex_attention 高效稀疏注意力
Triton 融合损失
10.2 代码阅读路径推荐
| 顺序 | 文件 | 目的 |
|---|---|---|
| 1 | README.md | 项目概览 |
| 2 | config/dspark/dspark_qwen3_4b.py | 理解全部超参数 |
| 3 | deepspec/data/target_cache_dataset.py | 理解数据格式 |
| 4 | deepspec/modeling/dspark/common.py | 锚点采样、注意力掩码 |
| 5 | deepspec/modeling/dspark/qwen3/modeling.py | DSpark 模型完整前向 |
| 6 | deepspec/modeling/dspark/markov_head.py | 顺序修正头实现 |
| 7 | deepspec/modeling/dspark/loss.py | 联合训练目标 |
| 8 | deepspec/trainer/base_trainer.py | 训练主循环 |
| 9 | deepspec/eval/base_evaluator.py | 投机解码 + 拒绝采样 |
| 10 | DSpark_paper.pdf | 论文原始描述 |
10.3 项目意义
DeepSpec 不仅开源了 DSpark 的训练代码,还一并开源了 DFlash 和 Eagle3 的实现,为投机解码研究提供了一个统一的、算法驱动的训练框架。结合 DeepSeek-V4 的生产部署结果(60%-85% 的单用户加速),DSpark 已被验证为当前最先进的投机解码算法之一。