Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

zhaoyingjun/Tiny-R2

Repository files navigation

Tiny-R2: A Hybrid Architecture Combining SWA, CSA, HCA, mHC and DSMoE Under the DeepSeek V4 Design Paradigm

模型结构 benchimark loss OPD训练过程和结果评估

Tiny-R2 模型架构与训练流程文档


📋 目录

  1. 项目概述
  2. 模型架构总览
  3. 核心组件详解
  4. 训练流程
  5. 关键技术特性
  6. 附录:图表索引

项目概述

Tiny-R2 是一个以快速复现DeepSeekV4/R2为目标的项目,目前已经实现如下的架构:

  • 稀疏注意力机制 (HCA-CSA Hybrid Attention)
  • 专家混合模型 (DeepSeek MoE)
  • 超连接技术 (Hyper-Connections)
  • 双优化器策略 (Muon + AdamW)
  • 支持AOPD后训练 (Agent On-policy distillation)目前实测在pubmed数据集上Qwen3.5-0.8B可以提升12%以上准确率,支持行业严肃场景数据通过OPD强化注入。

模型架构总览

模型结构

快速启动

2.1 安装依赖

pip install tiktoken datasets transformers huggingface_hub
pip install git+https://github.com/KellerJordan/Muon
pip install --upgrade transformers
hf auth login --force

2.2 启动训练

2.2.1支持采用Agent进行自主观察和训练调整lr、clip超参,以实现更加稳定的智能化训练;默认不开启;开启后需要使用gemini的api key

python train.py --n_layer 6 --n_embd 1536 --hc 'True' --mhc 'True' --n_experts 8 --max_iters 10000 --attention_types 'Sparse' --batch_size 8 --ctx_len 2048 --hf_dataset 'karpathy/climbmix-400b-shuffle' --resume True --save_best_only True

2.2.2 设置 --use_agent_observe开启Agent智能化训练供能,需要填入你的geimini的api key

python train.py --n_layer 6 --n_embd 1536 --hc 'True' --mhc 'True' --n_experts 8 --max_iters 10000 --attention_types 'Sparse' --batch_size 8 --ctx_len 2048 --hf_dataset 'karpathy/climbmix-400b-shuffle' --resume True --save_best_only True --use_agent_observe True --gemini_api_key "your gemini apikey"

2.3 验证模型训练效果PPL

python evaluate.py --checkpoint checkpoints/best_model_step_xxx.pt 

2.4 启动OPD在线蒸馏(以下三选一即可)

2.4.1使用Qwen3.5-9B模型作为教师模型、Qwen3.5-0.8B作为学生模型进行OPD训练,用RAG增加教师模型,RAG数据集来自问答数据集集medquad,可以复现Readme中的结果;(目前实测在pubmed数据集上Qwen3.5-0.8B可以提升12%以上准确率)

python opd_train.py --dataset pubmed_qa --hf_teacher_model Qwen/Qwen3.5-9B --student_model_name Qwen/Qwen3.5-0.8B --batch_size 2 --grad_accum_steps 4

2.4.2 使用Qwen3.5-9B模型作为教师模型、Qwen3.5-0.8B作为学生模型进行OPD训练,用RAG增加教师模型;--rag_corpus_path外挂RAG数据集,--custom_qa_path 自定义问题集

python opd_train.py --hf_teacher_model Qwen/Qwen3.5-9B --student_model_name Qwen/Qwen3.5-0.8B --batch_size 2 --grad_accum_steps 4 --custom_qa_path baoxianqa.jsonl --rag_corpus_path baoxianqa.txt
 

2.4.3 使用Qwen3.5-9B模型作为教师模型、Tiny-R2作为学生模型进行OPD训练,用RAG增加教师模型;--rag_corpus_path外挂RAG数据集,--custom_qa_path 自定义问题集

python opd_train.py --hf_teacher_model Qwen/Qwen3.5-9B --student_model_name tiny-r2 --student_ckpt checkpoints/best_model_step_40.pt --enable_rag_teacher --batch_size 2 --grad_accum_steps 4 --custom_qa_path baoxianqa.jsonl --rag_corpus_path baoxianqa.txt 

核心组件详解

3.1 注意力机制

Tiny-R2 支持FullAttention与Sparse Attention注意力类型,通过配置 attention_types 灵活切换:

3.1.1 CausalSelfAttention (Full Attention)

标准的因果自注意力机制:

class CausalSelfAttention(nn.Module):
 def __init__(self, config):
 # Projections: Q, K, V from single linear
 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
 self.c_proj = nn.Linear(config.n_embd, config.n_embd)
 
 # Value residual connections
 self.v_residual = config.v_residual
 self.lamb1 = nn.Parameter(torch.tensor(0.5))
 self.lamb2 = nn.Parameter(torch.tensor(0.5))
 
 # Flash Attention support
 self.flash = hasattr(F, "scaled_dot_product_attention")

关键特性:

  • 使用 Flash Attention 加速(如果可用)
  • 支持 Value Residual Connections
  • 标准的因果掩码

3.1.2 HCA-CSA Hybrid Attention

结合 HCA)和CSA 的混合注意力机制。

三种运行模式:

模式 分支配置 说明
HCA [1, 0, 0] 超级压缩分支
SWA [0, 0, 1] 滑动窗口分支
CSA [0, 1, 0] 压缩 + 选择分支

3.2 HCA-CSA 混合注意力

HCA-CSA 是 Tiny-R2 的核心创新之一,通过三个并行分支实现高效的稀疏注意力计算。

架构流程

 Input x
 ↓
┌──────────────────────────────────────────────────────────────┐
│ Query Preparation (HCA style) │
│ compress_q → q_norm → decompress_q → RoPE → Query │
└──────────────────────────────────────────────────────────────┘
 ↓
┌──────────────────┬──────────────────┬──────────────────────┐
│ Branch 1 │ Branch 2 │ Branch 3 │
│ Compression │ Selection │ Sliding Window │
│ (HCA) │ (CSA) │ (SWA) │
├──────────────────┼──────────────────┼──────────────────────┤
│ compress_kv │ importance_score │ window_k/v │
│ kv_norm │ topk selection │ sliding_window │
│ decompress_k/v │ selection_k/v │ RoPE │
│ k_rope │ RoPE │ │
│ K/V Recombine │ K/V Selected │ K/V Window │
└──────────────────┴──────────────────┴──────────────────────┘
 ↓ ↓ ↓
┌──────────────────────────────────────────────────────────────┐
│ Attention Computation │
│ Attention 1: (Q @ K1.T) @ V1 │
│ Attention 2: (Q @ K2.T) @ V2 │
│ Attention 3: (Q @ K3.T) @ V3 │
└──────────────────────────────────────────────────────────────┘
 ↓
branch_gate (Linear + Softmax) → Weighted Sum
 ↓
proj (Linear) → res_dropout → Output

关键参数

# HCA 参数
self.v_head_dim = 32
self.kv_lora_rank = 32
self.q_lora_rank = 3 * self.kv_lora_rank
self.rope_head_dim = 64
self.nope_head_dim = 32
# CSA 参数
self.block_size = config.block_size # Token压缩块大小
self.window_size = config.window_size # 滑动窗口大小
self.num_tokens_to_keep = config.num_tokens_to_keep # 选择保留的token数

3.3 前馈网络与 MoE

3.3.1 MLP

标准的前馈网络,使用 ReLU2 激活函数:

class MLP(nn.Module):
 def __init__(self):
 self.c_fc = nn.Linear(n_embd, 4 * n_embd)
 self.c_proj = nn.Linear(4 * n_embd, n_embd)
 
 def forward(self, x):
 x = self.c_fc(x)
 x = F.relu(x).square() # ReLU squared
 x = self.c_proj(x)
 return x

3.3.2 DSMoE (DeepSeek Mixture of Experts)

DeepSeek 风格的专家混合模型:

Input x [B, T, C]
 ↓
Gate Network (Linear + UnitCenteredNoise)
 ↓
Softmax → Top-k Selection
 ↓
┌─────────────────────────────────────────────────────────────┐
│ Expert Networks │
│ ┌──────────────┐ ┌──────────┐ ┌──────────┐ ┌────────┐ │
│ │ Shared Exp 0 │ │ Expert 1 │ │ Expert 2 │ │ ... │ │
│ │ (Always On) │ │ (Top-k) │ │ (Top-k) │ │ (Top-k)│ │
│ └──────────────┘ └──────────┘ └──────────┘ └────────┘ │
└─────────────────────────────────────────────────────────────┘
 ↓
Weighted Sum of Expert Outputs
 ↓
Output [B, T, C]

关键特性:

特性 说明
Shared Expert 始终激活的共享专家,提供稳定性
Routed Experts Top-k 选择的路由专家
Load Balance Loss 防止专家崩溃的负载均衡损失
Expert Bias 可学习的专家偏置,用于路由优化
UnitCenteredNoise 训练时添加噪声以增加探索

Load Balance Loss 计算:

def moe_load_balance_loss(router_weights, num_experts):
 load = router_weights.sum(dim=0)
 load = load / load.sum()
 ideal = torch.full_like(load, 1.0 / num_experts)
 loss = num_experts * torch.sum((load - ideal) ** 2)
 return loss

3.4 Hyper-Connections

Hyper-Connections 是 Tiny-R2 的另一大创新,通过多流路由机制增强信息流动。

核心概念:

# 初始化 Hyper-Connections
self.init_hc, self.expand_stream, self.reduce_stream = \
 get_init_and_expand_reduce_stream_functions(
 config.hc_num_streams,
 num_fracs=config.hc_num_fracs,
 disable=config.hc_disable,
 )
# 在每个 Block 中使用
self.hc_attn = init_hc(
 dim=config.n_embd,
 branch=self.attn_branch,
 layer_index=index * 2,
 mhc=config.mhc,
 sinkhorn_iters=config.sinkhorn_iters,
 sinkhorn_tau=config.sinkhorn_tau,
)

关键参数:

参数 说明
hc_num_streams 超连接流数量
hc_num_fracs 分段数量
mhc 多超连接配置
sinkhorn_iters Sinkhorn 算法迭代次数
sinkhorn_tau Sinkhorn 温度参数

训练流程

4.1 初始化阶段

Parse Arguments → Update Config → Init WandB → Setup Distributed → Setup AMP

4.2 数据准备

Load HF Dataset (flytech/python-codes-25k)
 ↓
Init GPT2 Tokenizer
 ↓
Create TokenBuffer

TokenBuffer 功能:

  • 流式读取 HuggingFace 数据集
  • 动态填充 token buffer
  • 生成连续的 token batch

4.3 模型初始化

Create Transformer
 ↓
Configure Optimizers (Muon + AdamW)
 ↓
Create LR Scheduler (Warmup + Cosine)

4.4 训练循环

For iter in range(max_iters):
 │
 ├── For step in grad_accum_steps:
 │ ├── Get Batch (TokenBuffer)
 │ ├── Forward Pass (model)
 │ ├── Backward Pass (scaler.scale)
 │ └── Collect Router Weights
 │
 ├── Gradient Clipping (clip_grad_norm_)
 ├── Optimizer Steps (Muon + AdamW)
 ├── Update Scaler (scaler.update)
 ├── LR Scheduler Step
 ├── Update Expert Biases (load balancing)
 └── Log Metrics (WandB)

4.5 评估与保存

If iter % eval_interval == 0:
 ├── Estimate Loss (eval mode)
 ├── Save Checkpoint (if val_loss < 5.27)
 └── Log to WandB

4.6 优化器配置

Tiny-R2 使用双优化器策略:

def configure_optimizers(self, weight_decay, learning_rate, device):
 muon_params = [] # ≥2D parameters in blocks
 adamw_params = [] # Other parameters
 
 for name, param in self.named_parameters():
 if 'blocks' in name and param.ndim >= 2:
 muon_params.append(param)
 else:
 adamw_params.append(param)
 
 return [
 Muon(muon_params, lr=0.02, momentum=0.95),
 torch.optim.AdamW(adamw_params, lr=learning_rate, 
 betas=(0.90, 0.95), weight_decay=weight_decay)
 ]

关键技术特性

5.1 注意力机制对比

特性 CausalSelfAttention HCA-CSA Hybrid
计算复杂度 O(n2) O(n) ~ O(n log n)
内存使用
适用场景 短序列 长序列
分支数量 1 3 (可配置)

5.2 FFN 类型对比

特性 MLP DSMoE
参数量 固定 共享 + 路由
计算量 固定 稀疏激活
表达能力 标准 更强
训练稳定性 需要负载均衡

5.3 核心配置参数

# 模型架构
n_embd = 512 # 嵌入维度
n_head = 8 # 注意力头数
n_layer = 8 # 层数
n_experts = 8 # 专家数量
num_exp = 2 # 每token激活的专家数
# 注意力配置
attention_types = ["FULL", "Spares", ...] # 每层注意力类型
attention_mode = ["FULL", "SWA", "CSA"] # 稀疏注意力模式
# Hyper-Connections
hc = True # 启用超连接
hc_num_streams = 4 # 流数量
# 训练
batch_size = 32
ctx_len = 512 # 上下文长度
lr = 1e-3
warmup_iters = 1000
max_iters = 100000

附录:图表索引

本文档配套图表保存在 /mnt/okcomputer/output/ 目录:

文件名 说明
model_architecture.png 模型整体架构图
loss.png 训练 20亿Tokens的loss收敛图
benchmark.png wikitext-103 benchmark图

参考资料


About

Tiny-R2: A hybrid architecture integrating SWA, CSA, HCA, mHC, and DSMoE under the DeepSeek V4 design paradigm, enabling single-GPU OPD post-training.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

Contributors

AltStyle によって変換されたページ (->オリジナル) /