Skip to content

Ling3-Tiny sglang-jax 推理集成

字段
作者@yuhao
日期2026-05-14
状态Draft
目标在 sglang-jax 接入 Ling3-Tiny 推理,基于 Engine API 对齐蚂蚁内部 4-shot GSM-8K(约 0.85)与蚂蚁内部 MMLU-pro(约 0.44

1. 背景

Ling3-Tiny 是一个 24 层的混合注意力 MoE 模型,模型结构将命名为 BailingMoeV3ForCausalLMmodel_typebailing_hybrid,目标是在 sglang-jax 中支持它的推理路径。

本文只做推理接入:

  • 新增 BailingMoeV3ConfigBailingMoeV3ForCausalLM
  • 接入 Ling3-Tiny 的 KDA / MLA 层结构、MoE、权重映射和 tokenizer / EOS。
  • 显式跳过 ckpt 中的 MTP 层,不实现 MTP 推理或投机解码。
  • 不做训练,不做离线 ckpt 转换。

本文中的 ckpt 指:

text
gs://inference-model-storage-poc-tpu-hns/Ling-3-tiny/model_ckpt/

1.1 模型结构

以下事实来自 ckpt config.json 和 safetensors header:

主干层数24
hidden size1536
attention heads16
layer group每 4 层一组:KDA, KDA, KDA, MLA
KDA head_dim128
MLA q/k dimqk_nope_head_dim=128 + qk_rope_head_dim=64 = 192
MLA v_head_dim128
Q-LoRAq_lora_rank=256
KV-LoRAkv_lora_rank=512
MLA output gatehead-wise sigmoid gate
KDAno_kda_lora=true, kda_safe_gate=true, kda_lower_bound=-5.0
MoE128 routed experts, top-8, 1 shared expert
first dense layerfirst_k_dense_replace=1,layer 0 是 dense MLP
routersigmoid grouped-topk, n_group=8, topk_group=4, routed_scaling_factor=2.5
router dtyperouter_dtype="fp32"
RoPErope_theta=10000, max_position_embeddings=8192, rope_interleave=true
EOSeos_token_id=156892

后续 §3.2 会进一步区分 linear_attn_config["head_dim"]=128ModelConfig.head_dim=192 这两个同名字段。

主干层分布:

0-based layer类型用 RoPE
0, 1, 2KDA
3MLA
4, 5, 6KDA
7MLA
8, 9, 10KDA
11MLA
12, 13, 14KDA
15MLA
16, 17, 18KDA
19MLA
20, 21, 22KDA
23MLA
24MTP推理跳过

等价列表:

text
KDA layers, 1-based:       [1,2,3, 5,6,7, 9,10,11, 13,14,15, 17,18,19, 21,22,23]
Full-attn layers, 1-based: [4,8,12,16,20,24]
MLA layers, 0-based:       [3,7,11,15,19,23]

注意:full_attn_layers 里的 24 是 1-based 主干第 24 层,也就是 ckpt model.layers.23。ckpt model.layers.24 是 MTP。

2. 文件依赖与修改

本文复用以下已有组件:

组件位置用途
KDA kernelsrt/kernels/kda/KDA prefill kernel,已支持 lower_bound 分支
KDA backendsrt/layers/attention/linear/kda_backend.py连接 RadixLinearAttention 与 KDA kernel
线性注意力抽象srt/layers/radix_linear_attention.pyKDA layer runtime 容器
KDA 基础层srt/models/kimi_linear.py:KimiDeltaAttention复用 q/k/v、conv、state、norm、projection 结构
MLA 基础层srt/models/deepseek_v3.py:DeepseekV3Attention复用 MLA Q-LoRA / KV-LoRA / RoPE / attention
MoE 组件srt/models/bailing_moe.pysrt/layers/gate.py复用 router dtype、grouped-topk、expert bias 语义
model registrysrt/models/registry.py通过 EntryClass 发现模型类
HF config registrysrt/hf_transformers_utils.py通过 _CONFIG_REGISTRY 注册 model_type

本文需要修改或新增:

文件改动
srt/configs/bailing_moe_v3.py新增 BailingMoeV3Config
srt/hf_transformers_utils.py注册 bailing_hybrid -> BailingMoeV3Config
srt/models/deepseek_v3.py行为不变地拆分 DeepseekV3Attention,让 Ling3 可以在 o_proj 前插 gate
srt/models/bailing_moe_v3.py新增 Ling3-Tiny 模型、decoder、BailingMLABailingKDAAttention、权重映射
srt/layers/radix_linear_attention.pyRadixLinearAttention.__init__ 新增 lower_bound=None
srt/layers/attention/linear/kda_backend.pyprefill/decode 两条 KDA 路径使用 lower_bound

注:naive.py 本身没有 gate 激活逻辑。lower_bound 分支需要在两处实现并保持数学等价:prefill 走 chunk_kda,由 kernel 内的 kda_gate_chunk_cumsumsrt/kernels/kda/kda.py)处理;decode 走 KDAAttnBackend._fused_kda_gatesrt/layers/attention/linear/kda_backend.py)。后者 docstring 已声明 "Mirrors kda_gate_chunk_cumsum so prefill and decode produce identical activations",本文沿用这一约束。

3. 设计方案

3.1 Config 与注册

新增 BailingMoeV3Config,字段名以 ckpt 为准。需要补以下别名,供现有 sglang-jax 组件读取:

ckpt 字段sglang-jax 字段用途
num_experts_per_toknum_experts_per_tokenMoE top-k
n_groupnum_expert_groupgrouped routing
norm_topk_probmoe_renormalizetop-k 权重归一化
score_function / scoring_funcmoe_router_activation_funcrouter 激活函数
use_mla_nopemla_use_nope是否跳过 MLA RoPE

kda_safe_gatekda_lower_bound 保留原名。它们不加入 KimiLinearConfig,由 BailingMoeV3Config 读出后传给 BailingKDAAttention;在构造函数 BailingKDAAttention.__init__ 里沿用 sgl-jax KDA layer 的参数名 safe_gate / lower_bound

BailingKDAAttention 是新建的 nnx.Module,不继承现有 KDAAttention,attention 部分直接持有一个 RadixLinearAttention 实例。这样 KimiLinearConfig 的字段不会被 BailingMoeV3Config 反向约束,避免出现 ckpt 字段缺失或别名冲突。

注册方式:

  • srt/hf_transformers_utils.py_CONFIG_REGISTRY 中注册 bailing_hybrid -> BailingMoeV3Config
  • srt/models/bailing_moe_v3.py 中设置 EntryClass = BailingMoeV3ForCausalLM,由 model registry 自动发现。

3.2 linear_attn_confighead_dim

运行时用 hf_config.linear_attn_config is not None 判断是否启用混合注意力路径。因此 config 初始化时必须生成:

python
kda_layers = [i for i in range(1, num_hidden_layers + 1) if i % layer_group_size != 0]
full_attn_layers = [i for i in range(1, num_hidden_layers + 1) if i % layer_group_size == 0]

self.linear_attn_config = {
    "kda_layers": kda_layers,                          # 1-based
    "full_attn_layers": full_attn_layers,              # 1-based
    "head_dim": head_dim,                              # KDA: 128
    "num_heads": num_attention_heads,                  # 16
    "short_conv_kernel_size": short_conv_kernel_size,  # 4
}

BailingMoeV3ForCausalLM 还需要实现 patch_model_config

python
@classmethod
def patch_model_config(cls, mc: ModelConfig) -> None:
    mc.attention_arch = AttentionArch.MLA
    mc.head_dim = mc.hf_text_config.qk_nope_head_dim + mc.hf_text_config.qk_rope_head_dim

这里有两个 head_dim,不能混用:

字段用途
linear_attn_config["head_dim"]128KDA q/k/v projection 和 recurrent state
ModelConfig.head_dim192MLA q/k 拼接维度(qk_nope_head_dim + qk_rope_head_dim),由 ModelRunner 推导 full-attention KV pool 形状时使用

ModelConfig.head_dim=192 不是 absorbed MLA 内部 latent KV 维度。absorbed 路径里 attn_mqa.head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576,由 DeepseekV3Attention 自己构造。

3.3 MLA:在 o_proj 前加 head-wise gate

Ling3 MLA 复用 DeepseekV3Attention 的主体流程,只在 o_proj 前多乘一个 head-wise gate。

本节 shape 约定:

text
T = 当前 batch flatten 后的 token 数
H = num_attention_heads = 16
R = kv_lora_rank = 512
N = qk_nope_head_dim = 128
P = qk_rope_head_dim = 64
V = v_head_dim = 128

Ling3 MLA 前向流程

BailingMLA 的前半段有三条支路,其中 Q/KV 两条复用 DeepseekV3Attention

text
hidden [T,1536]
├─ Q path
│  └─ q_a_proj -> [T,256]
│     -> q_a_layernorm
│     -> q_b_proj -> [T,H*(N+P)] = [T,3072]
│     -> reshape [T,H,192]
│     -> split:
│        q_nope [T,H,N] = [T,16,128]
│        q_rope [T,H,P] = [T,16,64]

├─ KV path
│  └─ kv_a_proj_with_mqa -> [T,R+P] = [T,576]
│     -> split:
│        compressed_kv [T,R] = [T,512]
│        k_rope_raw   [T,P] = [T,64]
│     -> kv_a_layernorm(compressed_kv) -> [T,512]
│     -> reshape k_rope_raw -> k_rope [T,1,64]

└─ Ling3 gate path
   └─ g_proj -> [T,H] = [T,16]
      -> sigmoid in fp32 -> gate [T,16]

数据流(实现细节见后面「实现流程」),后半段按下面的顺序汇合:

text
q_rope [T,16,64], k_rope [T,1,64]
  -> RoPE with rope_interleave=true / is_neox_style=False

q_nope [T,16,128]
q_rope [T,16,64]
compressed_kv [T,512]
k_rope [T,1,64]
  -> MLA attention core
  -> pre_o_proj [T,2048]

pre_o_proj [T,2048]
  -> reshape [T,16,128]
  -> multiply gate[..., None]       # gate [T,16] -> [T,16,1]
  -> flatten [T,2048]
  -> o_proj [T,1536]

实现流程

DeepseekV3Attention 的 forward 拆成两个 protected method:

text
_attention_core(positions, hidden_states, forward_batch, token_to_kv_pool) -> (pre_o_proj, kv_fused)
_apply_o_proj(pre_o_proj, hidden_states) -> output

_attention_core 包含 Q path、KV path、layernorm、RoPE、_forward_mqa / _forward_mha 和最后的 flatten;不包含 o_proj。返回的 pre_o_proj 形状必须是 [T,num_heads*v_head_dim] = [T,2048]

DeepseekV3Attention.__call__ 仍按原顺序调用:

text
pre_o_proj, kv_fused = self._attention_core(...)
output = self._apply_o_proj(pre_o_proj, hidden_states)
return output, kv_fused

BailingMLA override _apply_o_proj,只在父类 o_proj 前插入 gate:

text
gate = sigmoid(g_proj(hidden_states))       # [T,16], fp32
pre_o_proj = pre_o_proj.reshape(T,16,128)
pre_o_proj = pre_o_proj * gate[..., None]
pre_o_proj = pre_o_proj.reshape(T,2048)
output = super()._apply_o_proj(pre_o_proj, hidden_states)

这里的设计取舍是:BailingMLA 的 gate 需要原始 hidden_states;如果 override __call__ 会复制整段 DeepseekV3Attention forward,所以只 override _apply_o_proj,并让父子类签名都带上 hidden_states。父类忽略这个参数,子类用它计算 gate。这样 BailingMLA 只新增 gate 逻辑,不复制 DeepseekV3Attention forward;DeepSeek-V3 / DeepSeek-V2 / Kimi-Linear 的数值路径保持不变,由 §4 测试方案里的既有测试覆盖。

RoPE 参数必须显式传:

text
qk_rope_head_dim = 64
max_position_embeddings = 8192
rope_theta = 10000
rope_interleave = true
is_neox_style = False

其中 is_neox_style=False 是 sglang-jax 对 rope_interleave=true 的实现开关

partial_rotary_factor=0.5 在 ckpt 里的语义已经由 qk_rope_head_dim=64 表达(128 * 0.5 = 64)。DeepseekV3Attention 直接使用 qk_rope_head_dim 作为 RoPE 维度,不需要额外传 partial_rotary_factor

3.4 KDA:直连 f_proj/g_proj,保留 gated RMSNorm

Ling3 KDA 与 Kimi-Linear KDA 的主要差异:

Kimi-LinearLing3-Tiny
decay gatef_a_proj -> f_b_projf_proj 直连
output gateg_a_proj -> g_b_projg_proj 直连
output gate 粒度element-wiseelement-wise
output normGatedRMSNorm仍是 GatedRMSNorm 语义
KDA gate lower bound默认 None-5.0

BailingKDAAttention 是新建 class,不继承 KimiDeltaAttention,绝大部分 forward 逻辑从 KimiDeltaAttention copy 后做下面的差异修改,避免在父类里堆 if/else:

  1. 复用 Kimi KDA 的 q/k/v projection、short conv、A_logdt_biasb_projRadixLinearAttentionGatedRMSNormo_proj
  2. f_a_proj/f_b_proj 换成直连 f_proj
  3. g_a_proj/g_b_proj 换成直连 g_proj
  4. output_gate = g_proj(hidden) 后仍走 GatedRMSNorm(o, output_gate),即 RMSNorm(o) * sigmoid(output_gate)
  5. 设置 RadixLinearAttention.lower_bound = -5.0

KDA backend 需要补完整透传链:

text
RadixLinearAttention.__init__(lower_bound=None)
  ├─ prefill: kda_backend._forward_extend(...)
  │             -> chunk_kda(..., lower_bound=layer.lower_bound)
  │                  -> kda_gate_chunk_cumsum(..., lower_bound=...)
  └─ decode:  kda_backend._fused_kda_gate(layer, g)
                -> lower_bound is None: -exp(A_log) * softplus(g + dt_bias)
                -> lower_bound set:     lower_bound * sigmoid(exp(A_log) * (g + dt_bias))

prefill 在 kernel 内、decode 在 JAX 侧分别实现,公式必须保持一致(参见 §2 注)。

默认 lower_bound=None 时,Kimi-Linear 已验证路径应保持不变。

3.5 MoE 与 dense MLP

权重映射见 §3.6;本节只讲 dense / MoE 层的运行时配置。实现上复用 bailing_moe.py 的做法:

  • router_dtype="fp32" 透传给 GateLogit,因此 gate weight 和 expert bias 都用 fp32。
  • TopK._biased_grouped_topk 已满足 Ling3 语义:expert bias 只参与 top-k 选择,实际权重来自未加 bias 的 sigmoid 分数。
  • 复用现有 TopK(routed_scaling_factor=2.5) 通路;权重已在 TopK 内乘过,MoE 输出处不要再乘一次。

3.6 权重映射

建议新模型的模块名尽量贴近 ckpt:word_embeddingsattentionmlp。这样 mapping 主要处理叶子名、转置和少量 reshape。

线性层权重布局约定:

text
PyTorch / HF ckpt: nn.Linear(in_features, out_features).weight = [out_features, in_features]
sglang-jax LinearBase: weight = [input_size, output_size]

因此所有加载到 LinearBase 的 2D linear 权重都需要转置:

text
ckpt [out, in] -> JAX [in, out]

权重映射按 6 类处理。除 MTP skip list 外,加载结束后不允许留下 unmapped ckpt key。

全局:

ckptJAX处理
model.word_embeddings.weightmodel.word_embeddings.embedding直接加载
model.norm.weightmodel.norm.scale直接加载
lm_head.weightlm_head.embedding直接加载

每个 decoder layer 共有:

ckpt key suffixJAX处理
input_layernorm.weightinput_layernorm.scale直接加载
post_attention_layernorm.weightpost_attention_layernorm.scale直接加载

KDA attention 有 13 个权重。注意:ckpt 里 attention.g_proj.weight 的 shape 因 layer 类型而异(KDA [2048,1536] / MLA [16,1536]),loader 必须先判 layer 类型再选映射。

ckpt key suffix处理
q_proj.weight, k_proj.weight, v_proj.weight转置
q_conv1d.weight, k_conv1d.weight, v_conv1d.weight[D,1,K] -> [D,K]
f_proj.weightdecay gate,转置
g_proj.weightelement-wise output gate,shape [2048,1536] -> [1536,2048],输出维度 num_heads * head_dim = 2048
b_proj.weightbeta projection,转置
A_log[H] -> [1,1,H,1]
dt_bias直接加载
o_norm.weight直接加载到 GatedRMSNorm.weight
o_proj.weight转置

MLA attention 有 8 个权重:

ckpt key suffix处理
q_a_proj.weight转置
q_a_layernorm.weight直接加载
q_b_proj.weight转置
kv_a_proj_with_mqa.weight转置
kv_a_layernorm.weight直接加载
kv_b_proj.weight转置;absorbed MLA 会在 post-load 拆成 w_uk/w_uv
g_proj.weightMLA head-wise gate,shape [16,1536] -> [1536,16]
dense.weight映射到 JAX attention.o_proj.weight,shape [1536,2048] -> [2048,1536]

Dense MLP 只出现在 layer 0:

ckpt keyJAX处理
model.layers.0.mlp.gate_proj.weightmodel.layers.0.mlp.gate_proj.weight转置
model.layers.0.mlp.up_proj.weightmodel.layers.0.mlp.up_proj.weight转置
model.layers.0.mlp.down_proj.weightmodel.layers.0.mlp.down_proj.weight转置

MoE 出现在 layer 1..23:

ckpt key patternJAX处理
model.layers.{i}.mlp.gate.weightmoe_gate.kernelrouter weight,转置,fp32
model.layers.{i}.mlp.gate.expert_biasmoe_gate.biasrouter expert bias,直接加载,fp32
model.layers.{i}.mlp.experts.{e}.gate_proj.weightrouted experts gate_proj通过 create_moe_weights_mapping 聚合并转置加载
model.layers.{i}.mlp.experts.{e}.up_proj.weightrouted experts up_proj通过 create_moe_weights_mapping 聚合并转置加载
model.layers.{i}.mlp.experts.{e}.down_proj.weightrouted experts down_proj通过 create_moe_weights_mapping 聚合并转置加载
model.layers.{i}.mlp.shared_experts.gate_proj.weightshared_experts.gate_proj.weight转置
model.layers.{i}.mlp.shared_experts.up_proj.weightshared_experts.up_proj.weight转置
model.layers.{i}.mlp.shared_experts.down_proj.weightshared_experts.down_proj.weight转置

MTP 显式跳过:

  • skip list 覆盖 model.layers.24.attention.*model.layers.24.input_layernorm.*model.layers.24.post_attention_layernorm.*model.layers.24.mlp.*model.layers.24.eh_proj.*model.layers.24.enorm.*model.layers.24.hnorm.*model.layers.24.final_layernorm.*
  • 加载结束后必须 assert unmapped_ckpt_keys ⊆ MTP_SKIP_LIST,并 raise 到主流程,不允许只 log warning。

3.7 Serving

  • tokenizer 使用 HuggingFace tokenizer fast 路径。
  • 从 HF config/tokenizer 读取 eos_token_id=156892,并确保它进入 sampler / stop criteria。
  • 完成本文的 config registry 后,启动命令里的 --trust-remote-code 不改变 tokenizer fast 路径;保留它是为了兼容 ckpt 后续添加自定义 tokenizer/config 的情况。

4. 测试方案

本文复用了大量已验证组件,组件已有单测只能证明组件自身正确,不能证明 Ling3-Tiny 的组装正确。最终验收标准是 §4.2 的 GSM-8K 端到端精度;中间阶段以 §4.1 既有测试和 §3.6 loader assert 兜底。P0 / P2 不引入新 pytest 用例;P2 的 loader sanity 是 loader 内部 runtime assert,不是单独测试。

4.1 P1 既有测试不回归

P1 改动了 DeepseekV3Attention 拆分与 RadixLinearAttention + KDA backend lower_bound 透传,以下既有测试必须全部通过:

测试文件覆盖路径
python/sgl_jax/test/test_mla_attention.pyabsorbed-MLA backend prefill / decode / sliding window / soft cap
python/sgl_jax/test/test_kda_attention.pyKDA backend prefill / decode / mixed
python/sgl_jax/test/test_kda_attention_dp.pyKDA backend DP=4/TP=1、DP=2/TP=2、状态隔离

4.2 P3 端到端精度

评测使用 sglang-jax in-process Engine API,不走 HTTP service / evalscope。数据集统一使用蚂蚁内部版本:GSM-8K 用 4-shot 版(与 qingyuan 报出 0.85 时一致),MMLU-pro 用蚂蚁内部评测版。

评测配置:

TPUGKE TPU v6e 4×4,4 hosts × 4 chips
并行TP=16 / DP=1 / nnodes=4
dtypebf16
promptsgsm_8k.json 预构造蚂蚁内部 4-shot prompts;MMLU-pro 用蚂蚁内部评测版 prompts
generationtemperature=0, max_tokens=2000
parserextract_answer last-numeric parser
输出只有 rank-0 写 predictions jsonl 和打印分数

eval_gsm8k_sgl_jax.pygsm_8k.json 和模型 ckpt 放到每个 pod 可访问的位置,然后 4 个 host 都执行同一批 prompts:

bash
NODE_RANK=$i \
DIST_INIT_ADDR=$RANK0_IP:29500 \
JAX_COMPILATION_CACHE_DIR=<JIT_CACHE_DIR> \
python -u <EVAL_SCRIPT> \
    <MODEL_PATH> \
    <GSM8K_JSON> \
    <OUTPUT_JSONL> \
  2>&1 | tee <LOG_PATH>

Engine(...) 需要固定以下非默认参数:

text
tp_size=16
nnodes=4
page_size=256
disable_overlap_schedule=True
precompile_token_paddings=[16,128,512]
disable_radix_cache=True
chunked_prefill_size=512
mem_fraction_static=0.6
max_running_requests=64

验收标准

数据集期望
GSM-8K(蚂蚁 4-shot)接近参考值 0.85
MMLU-pro(蚂蚁内部版)目标 0.44
稳定性T=0 下连续运行结果应稳定;rank-0 log 出现 正确率: NN/<total> = 0.NNNN 作为完成信号

5. 实施计划

按 PR 顺序拆分,每个 PR 独立可 review:

  • P0 — Config 注册

    • 内容:BailingMoeV3Config_CONFIG_REGISTRY 注册 bailing_hybridlinear_attn_config 派生、patch_model_config(详见 §3.1 / §3.2)
    • 门槛:PR review
  • P1 — 公共基础设施改动

    • 内容:拆分 DeepseekV3Attentiono_proj 前后;补 KDA lower_boundRadixLinearAttention + KDA backend prefill / decode 的透传(详见 §3.3 / §3.4)
    • 门槛:§4.1 既有测试全部通过
  • P2 — Ling3 模型与 loader

    • 内容:BailingMLABailingKDAAttention、decoder / model / LM head、权重映射、MTP 显式跳过(详见 §3.3 – §3.6)
    • 门槛:PR review;ckpt 全量加载必须满足 unmapped_ckpt_keys ⊆ MTP_SKIP_LIST 的 assert(loader sanity,作为 P3 的前置条件)
  • P3 — 端到端精度验证

    • 内容:在 GKE TPU v6e-4×4 通过 in-process Engine API 跑蚂蚁内部 4-shot GSM-8K 和蚂蚁内部 MMLU-pro(详见 §4.2)
    • 门槛:GSM-8K 接近 0.85;MMLU-pro 接近 0.44