RFC-0026: Ling3 KDA Varlen 支持(packing 文档间隔离)
| 字段 | 值 |
|---|---|
| 作者 | @Garrybest |
| 日期 | 2026-05-06 |
| 状态 | Draft |
| 前置 | RFC-0012(Ling3 模型集成)已合并;KDA Pallas 内核接入 segment_ids 参数(独立 PR,由内核同学负责) |
概述
让 Ling3 在 packing 训练场景下,KDA 注意力分支也能感知文档边界("varlen" 语义):MaxText 侧把 decoder_segment_ids 一路透传到 KDA 模块和 Pallas 内核,由内核内部按 segment 重置递推状态 S。
核心设计:Pallas kernel 保持 [B, T, H, D] 接口(向后兼容现有调用方),kernel 内部根据 B 是否为 1 走单样本/批量路径。MaxText 在 wrapper 层用 jax.vmap 把 batch 维剥成单行,再通过 q[None] 扩成 B=1 喂给 kernel,以拿到 vmap 的行间结构性隔离。MLA 分支已支持,本 RFC 仅补齐 KDA 的链路。
背景
当前 packing 状态
scripts/pretrain_ling3_tiny.sh 走 Megatron 离线 packing:每条样本是 seq_len+1=8193 的 token 串,文档之间用 EOD(id=0)分隔。当前脚本设置:
PACKING="true"
RESET_ATTENTION_MASK="false" # ← 即将切到 true
EOD_MASK_LOSS="true"reset_attention_mask=false 时,MegatronSplitInputsTargets (src/maxtext/input_pipeline/input_pipeline_utils.py:1064-1066) 把 inputs_segmentation 写成全 1,attention 退化为单条因果 mask,跨文档可见。这是 Megatron-LM 默认行为,但不是真正的 varlen。
KDA 状态泄漏问题
KDA(KimiDeltaAttention,src/maxtext/layers/attention_kda.py)是线性注意力 + 递推状态:
S_t' = S_{t-1} * exp(g_t)
S_t = S_t' + beta_t · k_t ⊗ v_t不像 MLA 用 seg_q == seg_kv 等值 mask 就能屏蔽跨文档信息——KDA 的状态 S 沿时间维连续累加。即便 inputs_segmentation 携带了真实 doc 边界,只要不在边界处 reset 状态,文档 A 的信息会通过 S 泄漏到文档 B。
MLA 已支持,KDA 被三道墙挡住
| 位置 | 现状 |
|---|---|
src/maxtext/layers/decoders.py:847-852 起 | decoder_segment_ids 已通过 broadcast_args 透传到所有层(包括 Ling3 ScannableBlock) |
src/maxtext/models/ling3.py:262-276(MLA 分支) | 正常透传 decoder_segment_ids 给 MLA |
src/maxtext/models/ling3.py:280-286(KDA 分支) | 显式 drop,注释写明 "KDA does not support packed sequences; drop decoder_segment_ids" |
src/maxtext/layers/attention_kda.py:314-315 | __call__ 签名虽然接受 decoder_segment_ids,但 raise NotImplementedError("KDA does not yet support packed sequences.") |
src/maxtext/kernels/kda/__init__.py:14、src/maxtext/kernels/kda/pallas.py:16、tops.ops.kda.chunk_kda | wrapper 与底层内核没有 segment_ids 参数 |
tests/unit/kda_attention_test.py:212-217 | 反向测试 test_packed_sequences_not_supported 锁死了上述 raise |
设计目标
- 打通
decoder_segment_ids从 train.py 一路到 KDA Pallas 内核的透传链路 - 维持 MLA 分支零回归
- 与 Pallas 内核团队约定的
[B, T, H, D]+segment_ids: [B, T]API contract 严格对齐;由jax.vmap+ B=1 扩维提供行间结构性隔离 - 保持向后兼容:
segment_ids=None时行为与今天完全一致(已开 packing 但 reset_attention_mask=false 的训练不受影响)
前置依赖
| 依赖 | 说明 | 负责方 |
|---|---|---|
KDA Pallas 内核接入 segment_ids | tops.ops.kda.chunk_kda 新增 segment_ids 参数,内核内部在 doc 边界 reset 递推状态 S。详见下文 §3.1 API contract | 内核团队(独立 PR) |
本 RFC 假设上述内核 PR 已就绪或与本 PR 并行交付。MaxText 侧改动可在内核 PR 落地前以 stub 形式提前合入(segment_ids=None 时不传该参数到内核)。
方案
3.1 API contract
3.1.1 Kernel 团队 ↔ MaxText([B, T, H, D] 接口,B=1 等价单样本)
tops.ops.kda.chunk_kda 接受 [B, T, H, D] 形状(保持向后兼容现有调用方)。kernel 内部检查 B 值:B=1 走单样本路径,B>1 走批量路径(内部 flatten 或原生 batch)。MaxText 通过 jax.vmap 把外层 batch 剥掉、再用 q[None] 扩成 B=1,让 kernel 进入单样本路径,从而拿到 vmap 的行间结构性隔离。
| 项 | 值 |
|---|---|
| 参数名 | segment_ids |
| dtype | int32 |
| shape | [B, T](与 q/k/v/g 的 batch + seq 维 [B, T, H, D] 对齐) |
| 取值约定 | 从 1 开始,每个 doc 一个递增 id;0 保留为 padding(当前 MaxText 不会产生 0 segment,但保持兼容) |
segment_ids=None 时的行为 | 内核 fall back 到旧路径(等价于"整段连续累加 S"),与今天行为一致 |
| 边界 reset 语义 | 在 segment_ids[:, t] != segment_ids[:, t-1] 处把递推状态 S 清零;t=0 处永远视为新段起点 |
| MaxText 调用模式 | wrapper 通过 vmap 剥掉 batch + [None] 扩 B=1,kernel 看到 [1, T, H, D] |
3.1.2 MaxText caller ↔ MaxText wrapper(2D 批接口)
调用方(attention_kda.py)不感知 vmap,仍以批为单位传入:
| 项 | 值 |
|---|---|
| 参数名 | segment_ids |
| dtype | int32 |
| shape | [B, T] |
| 取值约定 | 与内核约定一致(从 1 开始,0 = padding) |
segment_ids=None 时的行为 | wrapper 不传 segment_ids 给 kernel,等价 None |
| 行间隔离机制 | jax.vmap + B=1 扩维结构性保证:vmap 把 batch 维剥到外层,每行通过 q[None] 单独喂给 kernel 的 B=1 路径,JAX 编译期保证不同行的中间张量与 kernel 状态互不可见 |
为什么 vmap + B=1 扩维而不是直接让 kernel 处理 B>1:
- 正确性硬保证:vmap 是 JAX 语义层面的"无状态映射",编译器保证 row
的任何中间值(包括 KDA 的递推状态 S)不会污染 row。这比"约定 + kernel 内 batch 循环正确"更可靠。 - 隔离责任剥离 kernel:kernel 的 B>1 路径(flatten / 原生 batch)是否正确处理 segment_ids 跨 row 隔离,需要单独验证。走 B=1 单样本路径,kernel 只关心单条样本里的 doc 边界 reset,逻辑更简单。
- 行间 segment id 撞车不再是问题:row 0 的
segment_ids=[1,1,2,2,3]与 row 1 的segment_ids=[1,1,1,2,2]在 vmap 下分别是两次独立 kernel 调用,数值相同也无法互相 attend。- 未来扩展性:如果将来需要 per-row 不同的
initial_state,vmapin_axes=0自然支持。依赖:Pallas 默认 vmap rule 在 trace 期把 B_outer 维并入 grid 最外层、改写 BlockSpec
index_map,kernel body 不变。q[None]引入的 unit B 维与 vmap 的 B_outer 维合并后,运行期 HLO 与"kernel 直接处理 B>1"几乎一致,无 sequential fallback 或 launch overhead × B 的问题。不需要内核显式标注vmap_method(那是pure_callback的概念,Pallas 不适用)。仅当默认 rule 不适用时,内核团队可用jax.custom_batching.custom_vmap自定义。本 RFC 假设默认 rule 适用。
3.2 涉及文件变更
| 文件 | 变更类型 | 说明 |
|---|---|---|
src/maxtext/kernels/kda/__init__.py | 修改 | chunk_kda wrapper 接受 [B, T, H, D],内部用 jax.vmap 把 batch 剥成单行,再 [None] 扩 B=1 喂给 kernel;新增 segment_ids: [B, T] 参数 |
src/maxtext/kernels/kda/pallas.py | 修改 | pallas_chunk_kda 是 thin wrapper,保持 [B, T, H, D] 形状直接转发到 tops.ops.kda.chunk_kda;新增 segment_ids: [B, T] 参数 |
src/maxtext/layers/attention_kda.py | 修改 | 移除 NotImplementedError;同步 chunk_size padding;在 _shard_map_chunk_kda 中加入 segment_ids 的 pspec 与转发(仍以 2D 调用 chunk_kda,vmap 隐藏在 wrapper 内) |
src/maxtext/models/ling3.py | 修改 | KDA 分支不再 drop decoder_segment_ids,转发给 self.attention(...) |
tests/unit/kda_attention_test.py | 修改 | 反转 test_packed_sequences_not_supported;新增 packed-sequence 正向测试 + 行间隔离测试 |
tests/unit/ling3_decoder_test.py | 新增用例 | 确保 decoder_segment_ids 进入 KDA 路径 |
scripts/pretrain_ling3_tiny.sh | 修改 | RESET_ATTENTION_MASK="true"(独立步骤,启用时再改) |
总改动估计 <100 行有效代码(不含测试与注释)。
3.3 内核 wrapper(src/maxtext/kernels/kda/__init__.py + src/maxtext/kernels/kda/pallas.py)
两个 wrapper 的角色分工:
pallas_chunk_kda:thin wrapper,保持[B, T, H, D]形状,纯转发到tops.ops.kda.chunk_kda。chunk_kda:MaxText 入口(保持 caller 视角不变),内部用jax.vmap把外层 batch 剥成单行,再[None]扩 B=1 后调用pallas_chunk_kda,触发 kernel 单样本路径。
src/maxtext/kernels/kda/pallas.py(thin wrapper, 保持 [B, T, H, D])
def pallas_chunk_kda(
q: jnp.ndarray, # [B, T, H, D_qk]
k: jnp.ndarray, # [B, T, H, D_qk]
v: jnp.ndarray, # [B, T, H, D_v]
g: jnp.ndarray, # [B, T, H, D_qk]
beta: jnp.ndarray, # [B, T, H]
scale: float | None = None,
initial_state: jnp.ndarray | None = None, # [B, H, D_qk, D_v] or None
output_final_state: bool = False,
chunk_size: int = 64,
A_log: jnp.ndarray | None = None, # [H] (无 batch 维)
dt_bias: jnp.ndarray | None = None, # [H * D_qk] (无 batch 维)
use_gate_in_kernel: bool = False,
use_qk_l2norm_in_kernel: bool = False,
safe_gate: bool = False,
lower_bound: float | None = None,
segment_ids: jnp.ndarray | None = None, # [B, T] ← 新增
) -> tuple[jnp.ndarray, jnp.ndarray | None]:
"""Thin wrapper preserving the kernel's [B, T, H, D] API.
The kernel itself dispatches on B: B=1 → single-sample path, B>1 →
flatten / native batch path. MaxText's `chunk_kda` always feeds B=1
via vmap + [None] expansion, so the kernel takes the single-sample path.
"""
o, final_state = tops_chunk_kda(
q=q, k=k, v=v, g=g, beta=beta,
A_log=A_log, dt_bias=dt_bias,
scale=scale, initial_state=initial_state,
output_final_state=output_final_state,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
use_gate_in_kernel=use_gate_in_kernel,
safe_gate=safe_gate, lower_bound=lower_bound,
chunk_size=chunk_size,
segment_ids=segment_ids,
)
return o, final_statesrc/maxtext/kernels/kda/__init__.py(vmap + B=1 扩维)
def chunk_kda(
q: jnp.ndarray, # [B, T, H, D_qk]
k: jnp.ndarray, # [B, T, H, D_qk]
v: jnp.ndarray, # [B, T, H, D_v]
g: jnp.ndarray, # [B, T, H, D_qk]
beta: jnp.ndarray, # [B, T, H]
scale: float | None = None,
initial_state: jnp.ndarray | None = None,
output_final_state: bool = False,
chunk_size: int = 64,
A_log: jnp.ndarray | None = None, # [H]
dt_bias: jnp.ndarray | None = None, # [H * D_qk]
use_gate_in_kernel: bool = False,
use_qk_l2norm_in_kernel: bool = False,
safe_gate: bool = False,
lower_bound: float | None = None,
segment_ids: jnp.ndarray | None = None, # [B, T] or None
) -> tuple[jnp.ndarray, jnp.ndarray | None]:
"""KDA entry point. Wraps the [B, T, H, D] Pallas kernel with `jax.vmap`
for row-level structural isolation.
Implementation: vmap strips the outer batch axis; `_per_row` re-expands
B=1 via [None] before calling the kernel, which triggers the kernel's
single-sample path. Row isolation is guaranteed by JAX's vmap semantics.
"""
def _per_row(q_r, k_r, v_r, g_r, beta_r, seg_r):
# vmap stripped B; re-add B=1 to satisfy kernel's [B, T, H, D] API
q_b, k_b, v_b = q_r[None], k_r[None], v_r[None]
g_b, beta_b = g_r[None], beta_r[None]
seg_b = seg_r[None] if seg_r is not None else None
o, final_state = pallas_chunk_kda(
q=q_b, k=k_b, v=v_b, g=g_b, beta=beta_b,
segment_ids=seg_b,
# 以下参数无 batch 维,vmap 通过 in_axes=None 不做映射
A_log=A_log, dt_bias=dt_bias,
scale=scale, initial_state=initial_state,
output_final_state=output_final_state, chunk_size=chunk_size,
use_gate_in_kernel=use_gate_in_kernel,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
safe_gate=safe_gate, lower_bound=lower_bound,
)
# squeeze B=1 back out so vmap can re-add the outer B
return o[0], (final_state[0] if final_state is not None else None)
# segment_ids=None 时通过 in_axes=None 直接透传 None,避免 vmap 出错
seg_in_axis = 0 if segment_ids is not None else None
o, final_state = jax.vmap(
_per_row,
in_axes=(0, 0, 0, 0, 0, seg_in_axis), # q, k, v, g, beta, segment_ids
)(q, k, v, g, beta, segment_ids)
return o, final_state
tops.ops.kda.chunk_kda侧的对应改动由内核团队在独立 PR 完成。本 RFC 假设其签名接受segment_ids: jax.Array | None、shape=[B, T]、dtype=int32,与pallas_chunk_kda[B, T, H, D]接口对齐。kernel 内部根据 B 是否为 1 dispatch 到单样本 / 批量路径。vmap + B=1 性能:
pl.pallas_call自带默认 vmap rule(trace 期把 batch 维并入 grid + 改写 BlockSpec index_map),无需内核侧任何标注。q[None]引入的 unit B 维与 vmap 的 B_outer 维在 trace 期合并,运行期 HLO 与"kernel 直接吃 B>1"几乎等价,无 sequential fallback、无 launch overhead × B 的问题。仅在默认 rule 不适用时才需要jax.custom_batching.custom_vmap。
3.4 KDA 模块(attention_kda.py)
移除拒绝逻辑
# attention_kda.py:314-315 删除以下两行
if decoder_segment_ids is not None:
raise NotImplementedError("KDA does not yet support packed sequences.")同步 padding 对齐
替换为对齐 padding 的处理(参考已有的 T % chunk_size != 0 分支,行 322-330)。当 KDA 因 chunk_size=64 对 T 做 padding 时,segment_ids 也需要等长 padding,pad 值 0(即 padding token,内核会忽略):
B, T, _ = hidden_states.shape
chunk_size = 64
if T % chunk_size != 0:
pad_len = chunk_size - (T % chunk_size)
hidden_states = jnp.pad(hidden_states, ((0, 0), (0, pad_len), (0, 0)))
if decoder_segment_ids is not None:
# pad 到对齐长度;pad_value=0 表示 padding token(与 1-based segment id 约定相容)
decoder_segment_ids = jnp.pad(decoder_segment_ids, ((0, 0), (0, pad_len)))
T = hidden_states.shape[1]
needs_unpad = True
else:
needs_unpad = False接入 shard_map
segment_ids 与 q/k/v 共享 (activation_batch, activation_norm_length) 这两条 mesh 轴(不需要 head/kv 维),新建一个 pspec:
seg_pspec = self._logical_to_mesh_axes(("activation_batch", "activation_norm_length"))shard_map 把 batch 沿 activation_batch 切成 [B_local, T, ...],然后调 chunk_kda(仍是 2D 入口);vmap 在 chunk_kda 内部把 B_local 维剥成单行喂给 1D kernel。shard_map 的 in_specs 不能容纳 None,所以仍需把"带 segment / 不带 segment"拆成两个 shard_map(推荐方案 A,详见下方备注):
# 公共参数
qkv_pspec = self._logical_to_mesh_axes(self.qkv_axis_names)
beta_pspec = self._logical_to_mesh_axes(self.beta_axis_names)
a_log_pspec = self._logical_to_mesh_axes(("activation_heads",))
dt_bias_2d_pspec = self._logical_to_mesh_axes(("activation_heads", "activation_kv"))
seg_pspec = self._logical_to_mesh_axes(("activation_batch", "activation_norm_length"))
dt_bias_2d = self.dt_bias.value.reshape(self.num_key_heads, self.key_head_dim)
if decoder_segment_ids is None:
# 旧路径:不传 segment_ids
@functools.partial(
jax.shard_map,
mesh=self.mesh,
in_specs=(qkv_pspec, qkv_pspec, qkv_pspec, qkv_pspec,
beta_pspec, a_log_pspec, dt_bias_2d_pspec),
out_specs=qkv_pspec,
check_vma=False,
)
def _kda_no_seg(q, k, v, g, beta, A_log, dt_bias_2d):
o, _ = chunk_kda( # 2D 入口,内部 vmap → 1D kernel
q=q, k=k, v=v, g=g, beta=beta,
A_log=A_log, dt_bias=dt_bias_2d.reshape(-1),
scale=scale, chunk_size=chunk_size,
initial_state=None, output_final_state=False,
use_qk_l2norm_in_kernel=False, use_gate_in_kernel=True,
safe_gate=safe_gate, lower_bound=lower_bound,
)
return o
o = _kda_no_seg(q, k, v, g, beta, self.A_log.value, dt_bias_2d)
else:
# varlen 路径:把 segment_ids 一并喂进 shard_map
@functools.partial(
jax.shard_map,
mesh=self.mesh,
in_specs=(qkv_pspec, qkv_pspec, qkv_pspec, qkv_pspec,
beta_pspec, a_log_pspec, dt_bias_2d_pspec, seg_pspec),
out_specs=qkv_pspec,
check_vma=False,
)
def _kda_with_seg(q, k, v, g, beta, A_log, dt_bias_2d, seg):
o, _ = chunk_kda(
q=q, k=k, v=v, g=g, beta=beta,
A_log=A_log, dt_bias=dt_bias_2d.reshape(-1),
scale=scale, chunk_size=chunk_size,
initial_state=None, output_final_state=False,
use_qk_l2norm_in_kernel=False, use_gate_in_kernel=True,
safe_gate=safe_gate, lower_bound=lower_bound,
segment_ids=seg, # ← [B_local, T],wrapper 内 vmap 拆成 [T]
)
return o
o = _kda_with_seg(q, k, v, g, beta, self.A_log.value, dt_bias_2d, decoder_segment_ids)shard_map None 入参的处理:与上一版相同——
shard_map的in_specs不能直接容纳None,所以拆成两个 shard_map(方案 A);segment_ids=None时连 stub 入参都没有,与今天 100% 等价。嵌套语义:从外到内是
shard_map → chunk_kda(vmap → [None] 扩 B=1 → pallas_chunk_kda → tops_chunk_kda)。shard_map 提供 mesh-level 分片,vmap+[None]提供 batch-level 隔离,kernel 看到[1, T, H, D]走单样本路径。三层职责正交,易于推理。
3.5 Ling3 解码层(ling3.py)
把 ling3.py:262-287 的整个 attention 分发块替换如下,关键差异是 KDA 分支补上 decoder_segment_ids kwarg、并去掉 "KDA does not support packed sequences" 注释:
if isinstance(self.attention, attention_mla.MLA):
attention_output, kv_cache = self.attention(
hidden_states,
hidden_states,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
out_sharding=self.out_sharding,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
else:
# KDA path — same call shape as Ling2's GLA branch.
attention_output, _ = self.attention(
hidden_states,
decoder_positions,
deterministic,
model_mode,
layer_idx=global_layer_idx,
decoder_segment_ids=decoder_segment_ids, # ← 新增;2D [B, T] 透传
)
kv_cache = NoneLing3ScannableBlock.__call__ 已经把 decoder_segment_ids 作为位置参数传给每层 Ling3MoEDecoderLayer(ling3.py:438-447),无需改动。
3.6 RoPE 行为说明
本 RFC 沿用 MaxText 当前的耦合设计:开启 reset_attention_mask=true 后,MegatronSplitInputsTargets (input_pipeline_utils.py:1047-1063) 同时让 inputs_position 按 doc 重置(每个 doc 内部 position 从 0 开始)。这是标准 varlen 行为(FlashAttention varlen + RoPE 通常这么用)。
KDA 自身不消费
decoder_positions(attention_kda.py:308已del decoder_positions),所以 RoPE 行为变化只影响 MLA 分支,与本 RFC 的 KDA 改动正交。
备选方案
| 方案 | 描述 | 否决理由 |
|---|---|---|
| 跳过 vmap,直接 2D 调用 | kernel 已接受 [B, T, H, D],可直接调用让 kernel 内部按 batch 维并行(B>1 路径) | 行间隔离靠"约定 + kernel 内 batch 循环正确",比 vmap+B=1 结构性保证弱;隔离正确性需 kernel 单独证明。已选 vmap+B=1 扩维方案 |
| Gate 注入法 | 不改内核,在 KDA 模块里把 doc 首 token 的门 g 置为 -∞,让 S' = S * exp(g) 在边界归零 | 内核团队已决定在内核内做处理,此方案作废 |
| 分段调用法 | 按 doc 拆开 packed 序列,对每段独立调用 chunk_kda | shape ragged,难批化;性能差;实现复杂 |
维持 reset_attention_mask=false 不上 varlen | 不做改动 | 与 Ling3 训练对齐 Megatron 参考的目标冲突;KDA 跨文档泄漏长期存在 |
影响范围
| 模块 | 影响 |
|---|---|
Ling3 训练(pretrain_ling3_tiny.sh) | 启用 RESET_ATTENTION_MASK=true 后,KDA 与 MLA 都按 doc 隔离;初期 loss 相对 baseline 会有变化(更接近 Megatron 参考) |
| Ling2 训练 | 零影响(Ling2 用 GLA,不走 KDA 路径) |
| KDA 单元测试 | 反向测试需要替换为正向测试 |
| Ling3 解码层测试 | 新增 segment_ids 透传路径覆盖 |
| 推理路径 | 当前 KDA 仍 raise NotImplementedError("KDA autoregressive mode not yet implemented.")(attention_kda.py:317-318),不在本 RFC 范围 |
| Checkpoint | 不影响参数 shape,已有 ckpt 直接续训可用 |
实施计划
| Phase | 内容 | 依赖 |
|---|---|---|
| P0 | 与内核同学最终 confirm §3.1.1 的 [B, T, H, D] kernel 接口 + segment_ids: [B, T] 形状;dump jax.vmap(pallas_chunk_kda)(内部 B=1 扩维) 的 HLO,与"kernel 直接吃 B>1"的 HLO 对比验证 Pallas 默认 vmap rule + B=1 路径等价、无 sequential fallback;定稿 API contract | — |
| P1 | wrapper PR:kernels/kda/__init__.py 加 chunk_kda 的 vmap 包装 + segment_ids 参数;pallas.py 切到 1D 接口 | P0 |
| P2 | KDA 模块 PR:移除 NotImplementedError,加 padding 与 shard_map 透传;ling3.py 去掉 drop | P1;内核侧 1D tops.ops.kda.chunk_kda 接受 segment_ids |
| P3 | 测试 PR:反转 test_packed_sequences_not_supported,新增正向测试 + 行间隔离测试 + ling3 集成测试 | P2 |
| P4 | 启用:scripts/pretrain_ling3_tiny.sh 切到 RESET_ATTENTION_MASK="true",跑 STEPS=20 烟测 + 100 步 baseline 对比 + vmap 性能 benchmark | P3 |
P1 可在内核 PR 未合入前先以 stub 形式 land(向后兼容,不改变行为)。P2 需要内核侧能力就绪。
测试方案
单元测试(tests/unit/kda_attention_test.py)
def test_packed_sequences_supported(self, mesh):
"""传入 segment_ids 不再 raise,输出 shape 正确。"""
def test_segment_ids_padding_alignment(self, mesh):
"""T 不能被 chunk_size=64 整除时,segment_ids 与 hidden_states 同步 pad。"""
def test_packed_equivalence_with_per_doc(self, mesh):
"""数值等价性(核心):
pack(doc_a, doc_b) + segment_ids → 输出
concat(run(doc_a), run(doc_b)) → 输出
两者在 fp32 下应近似 bit-exact(容忍 1e-5)。
需要内核侧 doc 边界 reset 实现正确。"""
def test_segment_ids_none_fallback(self, mesh):
"""segment_ids=None 时与现有行为完全一致(回归保护)。"""
def test_row_independence_under_vmap(self, mesh):
"""行间隔离硬验证:构造 batch=[row_a, row_b],只改 row_b 的 segment_ids,
断言 row_a 的输出 bit-exact 不变。证明 vmap 提供的结构性隔离生效。"""集成测试(tests/unit/ling3_decoder_test.py 新增)
def test_kda_branch_receives_segment_ids(self, mesh):
"""构造一个 KDA 位置的 layer,断言 decoder_segment_ids 不被 drop(mock attention 检查入参)。"""
def test_ling3_full_layer_with_segment_ids(self, mesh):
"""整个 Ling3GenericLayer 端到端跑 packed 输入,shape 正确,无 NaN。"""端到端烟测
| 步骤 | 验证项 |
|---|---|
STEPS=20 RESET_ATTENTION_MASK=true bash scripts/pretrain_ling3_tiny.sh | 不抛异常;step time 合理(KDA 内核加 segment 处理后预计 < 5% 退化) |
data["inputs_segmentation"][0] 打印检查 | 形态为 [1,1,...,2,2,...,3,...] 而非全 1 |
| reset=true vs reset=false 各跑 100 步 | loss 曲线对比,初期 loss 应略升后收敛更好 |
风险
| 风险 | 影响 | 缓解 |
|---|---|---|
API contract 与内核实现不一致(特别是 §3.1.1 的 [B, T, H, D] shape、segment_ids: [B, T] 形状与 1-based / 0=padding 约定) | 数值错误,可能初期 loss 正常但训练长期 diverge | P0 阶段双方文档化 contract;P3 加数值等价性测试,与 per-doc 独立运行做 bit-level 对比 |
Pallas 默认 vmap rule + B=1 扩维不适用于 KDA kernel(例如 kernel 内部依赖 program_id 顺序的副作用,或 B=1 与 B>1 路径在数值上不一致) | vmap + [None] 后 HLO 与"kernel 直接吃 B>1"不等价,可能数值错误或性能退化 | P0 让内核同学 dump 一次 jax.vmap(pallas_chunk_kda)(内部 B=1 扩维) 的 HLO,与"kernel 直接吃 B>1"的 HLO 对比;同时 bit-level 校验 B=1 与 B>1 路径数值一致;若不等价则切到 jax.custom_batching.custom_vmap 显式定义 |
| KDA 内核加 segment 处理后性能退化 | step time 上升 | 烟测阶段 benchmark;如果退化 >10%,与内核同学 review 是否可优化 |
shard_map 中 segment_ids 的 pspec 与 q/k/v 不一致导致编译错误或正确性问题 | 训练崩溃或 silent 错误 | 单元测试覆盖;首次启用前用小 mesh(如 2x2)跑 sanity check |
| Checkpoint 续训行为变化 | 现有 ling3-tiny-conv ckpt(reset=false 训出)切到 reset=true 后初期 loss 跳变 | 在切换前在 RFC/PR 描述中明确告知;先跑 100 步 baseline 对比,确认是预期跳变而非 bug |
MTP 层(layer_idx >= num_decoder_layers)走 MLA 分支,需要确认 decoder_segment_ids 在 MTP 路径上也正确透传 | MTP loss 异常 | 集成测试覆盖 MTP 路径;本 RFC 范围内不改 MTP 代码 |
向后兼容性
| 保证 | 机制 |
|---|---|
reset_attention_mask=false 训练完全等价 | KDA 模块在 decoder_segment_ids is None 时走旧路径;wrapper segment_ids=None 不传给内核 |
| Ling2 零影响 | Ling2 走 GLA(BailingMoeV2LinearAttention),不经过本 RFC 改动 |
| 现有 ling3 ckpt 直接可加载 | 不改变任何参数 shape 或命名 |
| 现有内核 wrapper 调用方零影响 | segment_ids 是带默认值的 kwarg |
开放问题
- Pallas 默认 vmap rule + B=1 扩维等价性(§3.3):P0 阶段需要内核同学 dump 一次
jax.vmap(pallas_chunk_kda)(内部q[None]扩 B=1) 的 HLO,与"kernel 直接吃 B>1"的 HLO 逐 op 对比,确认 grid 扩展 + BlockSpec index_map 改写后语义等价、无 sequential fallback,且 B=1 路径与 B>1 路径数值一致。若发现差异(例如 kernel 内部依赖program_id顺序的副作用,或 B=1/B>1 走不同代码路径),则切到jax.custom_batching.custom_vmap显式定义 vmap rule。P0 必须确认。 shard_mapNone 入参处理(§3.4):选 A(拆两分支)还是 B(构造全 1 占位)。倾向 A,PR 中可一并讨论。- 数值等价性测试的 tolerance:fp32 下 bit-exact 还是 1e-5 容忍?取决于内核内部是否走完全相同的浮点路径。待内核 PR 落地后定。
- 未来推理路径:KDA 推理(autoregressive prefill / decode)目前未实现。本 RFC 不涉及,但启用 varlen 后推理路径如何处理 segment_ids 需要后续 RFC。