Skip to content

RFC-0026: Ling3 KDA Varlen 支持(packing 文档间隔离)

字段
作者@Garrybest
日期2026-05-06
状态Draft
前置RFC-0012(Ling3 模型集成)已合并;KDA Pallas 内核接入 segment_ids 参数(独立 PR,由内核同学负责)

概述

让 Ling3 在 packing 训练场景下,KDA 注意力分支也能感知文档边界("varlen" 语义):MaxText 侧把 decoder_segment_ids 一路透传到 KDA 模块和 Pallas 内核,由内核内部按 segment 重置递推状态 S

核心设计:Pallas kernel 保持 [B, T, H, D] 接口(向后兼容现有调用方),kernel 内部根据 B 是否为 1 走单样本/批量路径。MaxText 在 wrapper 层用 jax.vmap 把 batch 维剥成单行,再通过 q[None] 扩成 B=1 喂给 kernel,以拿到 vmap 的行间结构性隔离。MLA 分支已支持,本 RFC 仅补齐 KDA 的链路。

背景

当前 packing 状态

scripts/pretrain_ling3_tiny.sh 走 Megatron 离线 packing:每条样本是 seq_len+1=8193 的 token 串,文档之间用 EOD(id=0)分隔。当前脚本设置:

bash
PACKING="true"
RESET_ATTENTION_MASK="false"   # ← 即将切到 true
EOD_MASK_LOSS="true"

reset_attention_mask=false 时,MegatronSplitInputsTargets (src/maxtext/input_pipeline/input_pipeline_utils.py:1064-1066) 把 inputs_segmentation 写成全 1,attention 退化为单条因果 mask,跨文档可见。这是 Megatron-LM 默认行为,但不是真正的 varlen

KDA 状态泄漏问题

KDA(KimiDeltaAttentionsrc/maxtext/layers/attention_kda.py)是线性注意力 + 递推状态:

text
S_t' = S_{t-1} * exp(g_t)
S_t  = S_t' + beta_t · k_t ⊗ v_t

不像 MLA 用 seg_q == seg_kv 等值 mask 就能屏蔽跨文档信息——KDA 的状态 S 沿时间维连续累加。即便 inputs_segmentation 携带了真实 doc 边界,只要不在边界处 reset 状态,文档 A 的信息会通过 S 泄漏到文档 B。

MLA 已支持,KDA 被三道墙挡住

位置现状
src/maxtext/layers/decoders.py:847-852decoder_segment_ids 已通过 broadcast_args 透传到所有层(包括 Ling3 ScannableBlock)
src/maxtext/models/ling3.py:262-276(MLA 分支)正常透传 decoder_segment_ids 给 MLA
src/maxtext/models/ling3.py:280-286(KDA 分支)显式 drop,注释写明 "KDA does not support packed sequences; drop decoder_segment_ids"
src/maxtext/layers/attention_kda.py:314-315__call__ 签名虽然接受 decoder_segment_ids,但 raise NotImplementedError("KDA does not yet support packed sequences.")
src/maxtext/kernels/kda/__init__.py:14src/maxtext/kernels/kda/pallas.py:16tops.ops.kda.chunk_kdawrapper 与底层内核没有 segment_ids 参数
tests/unit/kda_attention_test.py:212-217反向测试 test_packed_sequences_not_supported 锁死了上述 raise

设计目标

  1. 打通 decoder_segment_ids 从 train.py 一路到 KDA Pallas 内核的透传链路
  2. 维持 MLA 分支零回归
  3. 与 Pallas 内核团队约定的 [B, T, H, D] + segment_ids: [B, T] API contract 严格对齐;由 jax.vmap + B=1 扩维提供行间结构性隔离
  4. 保持向后兼容:segment_ids=None 时行为与今天完全一致(已开 packing 但 reset_attention_mask=false 的训练不受影响)

前置依赖

依赖说明负责方
KDA Pallas 内核接入 segment_idstops.ops.kda.chunk_kda 新增 segment_ids 参数,内核内部在 doc 边界 reset 递推状态 S。详见下文 §3.1 API contract内核团队(独立 PR)

本 RFC 假设上述内核 PR 已就绪或与本 PR 并行交付。MaxText 侧改动可在内核 PR 落地前以 stub 形式提前合入(segment_ids=None 时不传该参数到内核)。

方案

3.1 API contract

3.1.1 Kernel 团队 ↔ MaxText([B, T, H, D] 接口,B=1 等价单样本)

tops.ops.kda.chunk_kda 接受 [B, T, H, D] 形状(保持向后兼容现有调用方)。kernel 内部检查 B 值:B=1 走单样本路径,B>1 走批量路径(内部 flatten 或原生 batch)。MaxText 通过 jax.vmap 把外层 batch 剥掉、再用 q[None] 扩成 B=1,让 kernel 进入单样本路径,从而拿到 vmap 的行间结构性隔离。

参数名segment_ids
dtypeint32
shape[B, T](与 q/k/v/g 的 batch + seq 维 [B, T, H, D] 对齐)
取值约定从 1 开始,每个 doc 一个递增 id;0 保留为 padding(当前 MaxText 不会产生 0 segment,但保持兼容)
segment_ids=None 时的行为内核 fall back 到旧路径(等价于"整段连续累加 S"),与今天行为一致
边界 reset 语义segment_ids[:, t] != segment_ids[:, t-1] 处把递推状态 S 清零;t=0 处永远视为新段起点
MaxText 调用模式wrapper 通过 vmap 剥掉 batch + [None] 扩 B=1,kernel 看到 [1, T, H, D]

3.1.2 MaxText caller ↔ MaxText wrapper(2D 批接口)

调用方(attention_kda.py)不感知 vmap,仍以批为单位传入:

参数名segment_ids
dtypeint32
shape[B, T]
取值约定与内核约定一致(从 1 开始,0 = padding)
segment_ids=None 时的行为wrapper 不传 segment_ids 给 kernel,等价 None
行间隔离机制jax.vmap + B=1 扩维结构性保证:vmap 把 batch 维剥到外层,每行通过 q[None] 单独喂给 kernel 的 B=1 路径,JAX 编译期保证不同行的中间张量与 kernel 状态互不可见

为什么 vmap + B=1 扩维而不是直接让 kernel 处理 B>1

  1. 正确性硬保证:vmap 是 JAX 语义层面的"无状态映射",编译器保证 row i 的任何中间值(包括 KDA 的递推状态 S)不会污染 row j。这比"约定 + kernel 内 batch 循环正确"更可靠。
  2. 隔离责任剥离 kernel:kernel 的 B>1 路径(flatten / 原生 batch)是否正确处理 segment_ids 跨 row 隔离,需要单独验证。走 B=1 单样本路径,kernel 只关心单条样本里的 doc 边界 reset,逻辑更简单。
  3. 行间 segment id 撞车不再是问题:row 0 的 segment_ids=[1,1,2,2,3] 与 row 1 的 segment_ids=[1,1,1,2,2] 在 vmap 下分别是两次独立 kernel 调用,数值相同也无法互相 attend。
  4. 未来扩展性:如果将来需要 per-row 不同的 initial_state,vmap in_axes=0 自然支持。

依赖:Pallas 默认 vmap rule 在 trace 期把 B_outer 维并入 grid 最外层、改写 BlockSpec index_map,kernel body 不变。q[None] 引入的 unit B 维与 vmap 的 B_outer 维合并后,运行期 HLO 与"kernel 直接处理 B>1"几乎一致,无 sequential fallback 或 launch overhead × B 的问题。不需要内核显式标注 vmap_method(那是 pure_callback 的概念,Pallas 不适用)。仅当默认 rule 不适用时,内核团队可用 jax.custom_batching.custom_vmap 自定义。本 RFC 假设默认 rule 适用。

3.2 涉及文件变更

文件变更类型说明
src/maxtext/kernels/kda/__init__.py修改chunk_kda wrapper 接受 [B, T, H, D],内部用 jax.vmap 把 batch 剥成单行,再 [None] 扩 B=1 喂给 kernel;新增 segment_ids: [B, T] 参数
src/maxtext/kernels/kda/pallas.py修改pallas_chunk_kda 是 thin wrapper,保持 [B, T, H, D] 形状直接转发到 tops.ops.kda.chunk_kda;新增 segment_ids: [B, T] 参数
src/maxtext/layers/attention_kda.py修改移除 NotImplementedError;同步 chunk_size padding;在 _shard_map_chunk_kda 中加入 segment_ids 的 pspec 与转发(仍以 2D 调用 chunk_kda,vmap 隐藏在 wrapper 内)
src/maxtext/models/ling3.py修改KDA 分支不再 drop decoder_segment_ids,转发给 self.attention(...)
tests/unit/kda_attention_test.py修改反转 test_packed_sequences_not_supported;新增 packed-sequence 正向测试 + 行间隔离测试
tests/unit/ling3_decoder_test.py新增用例确保 decoder_segment_ids 进入 KDA 路径
scripts/pretrain_ling3_tiny.sh修改RESET_ATTENTION_MASK="true"(独立步骤,启用时再改)

总改动估计 <100 行有效代码(不含测试与注释)。


3.3 内核 wrapper(src/maxtext/kernels/kda/__init__.py + src/maxtext/kernels/kda/pallas.py

两个 wrapper 的角色分工:

  • pallas_chunk_kda:thin wrapper,保持 [B, T, H, D] 形状,纯转发到 tops.ops.kda.chunk_kda
  • chunk_kda:MaxText 入口(保持 caller 视角不变),内部用 jax.vmap 把外层 batch 剥成单行,再 [None] 扩 B=1 后调用 pallas_chunk_kda,触发 kernel 单样本路径。

src/maxtext/kernels/kda/pallas.py(thin wrapper, 保持 [B, T, H, D]

python
def pallas_chunk_kda(
    q: jnp.ndarray,                          # [B, T, H, D_qk]
    k: jnp.ndarray,                          # [B, T, H, D_qk]
    v: jnp.ndarray,                          # [B, T, H, D_v]
    g: jnp.ndarray,                          # [B, T, H, D_qk]
    beta: jnp.ndarray,                       # [B, T, H]
    scale: float | None = None,
    initial_state: jnp.ndarray | None = None,  # [B, H, D_qk, D_v] or None
    output_final_state: bool = False,
    chunk_size: int = 64,
    A_log: jnp.ndarray | None = None,        # [H]            (无 batch 维)
    dt_bias: jnp.ndarray | None = None,      # [H * D_qk]     (无 batch 维)
    use_gate_in_kernel: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    segment_ids: jnp.ndarray | None = None,  # [B, T]         ← 新增
) -> tuple[jnp.ndarray, jnp.ndarray | None]:
  """Thin wrapper preserving the kernel's [B, T, H, D] API.

  The kernel itself dispatches on B: B=1 → single-sample path, B>1 →
  flatten / native batch path. MaxText's `chunk_kda` always feeds B=1
  via vmap + [None] expansion, so the kernel takes the single-sample path.
  """
  o, final_state = tops_chunk_kda(
      q=q, k=k, v=v, g=g, beta=beta,
      A_log=A_log, dt_bias=dt_bias,
      scale=scale, initial_state=initial_state,
      output_final_state=output_final_state,
      use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
      use_gate_in_kernel=use_gate_in_kernel,
      safe_gate=safe_gate, lower_bound=lower_bound,
      chunk_size=chunk_size,
      segment_ids=segment_ids,
  )
  return o, final_state

src/maxtext/kernels/kda/__init__.py(vmap + B=1 扩维)

python
def chunk_kda(
    q: jnp.ndarray,                          # [B, T, H, D_qk]
    k: jnp.ndarray,                          # [B, T, H, D_qk]
    v: jnp.ndarray,                          # [B, T, H, D_v]
    g: jnp.ndarray,                          # [B, T, H, D_qk]
    beta: jnp.ndarray,                       # [B, T, H]
    scale: float | None = None,
    initial_state: jnp.ndarray | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
    A_log: jnp.ndarray | None = None,        # [H]
    dt_bias: jnp.ndarray | None = None,      # [H * D_qk]
    use_gate_in_kernel: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    segment_ids: jnp.ndarray | None = None,  # [B, T] or None
) -> tuple[jnp.ndarray, jnp.ndarray | None]:
  """KDA entry point. Wraps the [B, T, H, D] Pallas kernel with `jax.vmap`
  for row-level structural isolation.

  Implementation: vmap strips the outer batch axis; `_per_row` re-expands
  B=1 via [None] before calling the kernel, which triggers the kernel's
  single-sample path. Row isolation is guaranteed by JAX's vmap semantics.
  """

  def _per_row(q_r, k_r, v_r, g_r, beta_r, seg_r):
    # vmap stripped B; re-add B=1 to satisfy kernel's [B, T, H, D] API
    q_b, k_b, v_b = q_r[None], k_r[None], v_r[None]
    g_b, beta_b = g_r[None], beta_r[None]
    seg_b = seg_r[None] if seg_r is not None else None
    o, final_state = pallas_chunk_kda(
        q=q_b, k=k_b, v=v_b, g=g_b, beta=beta_b,
        segment_ids=seg_b,
        # 以下参数无 batch 维,vmap 通过 in_axes=None 不做映射
        A_log=A_log, dt_bias=dt_bias,
        scale=scale, initial_state=initial_state,
        output_final_state=output_final_state, chunk_size=chunk_size,
        use_gate_in_kernel=use_gate_in_kernel,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        safe_gate=safe_gate, lower_bound=lower_bound,
    )
    # squeeze B=1 back out so vmap can re-add the outer B
    return o[0], (final_state[0] if final_state is not None else None)

  # segment_ids=None 时通过 in_axes=None 直接透传 None,避免 vmap 出错
  seg_in_axis = 0 if segment_ids is not None else None
  o, final_state = jax.vmap(
      _per_row,
      in_axes=(0, 0, 0, 0, 0, seg_in_axis),  # q, k, v, g, beta, segment_ids
  )(q, k, v, g, beta, segment_ids)
  return o, final_state

tops.ops.kda.chunk_kda 侧的对应改动由内核团队在独立 PR 完成。本 RFC 假设其签名接受 segment_ids: jax.Array | Noneshape=[B, T]dtype=int32,与 pallas_chunk_kda [B, T, H, D] 接口对齐。kernel 内部根据 B 是否为 1 dispatch 到单样本 / 批量路径。

vmap + B=1 性能pl.pallas_call 自带默认 vmap rule(trace 期把 batch 维并入 grid + 改写 BlockSpec index_map),无需内核侧任何标注。q[None] 引入的 unit B 维与 vmap 的 B_outer 维在 trace 期合并,运行期 HLO 与"kernel 直接吃 B>1"几乎等价,无 sequential fallback、无 launch overhead × B 的问题。仅在默认 rule 不适用时才需要 jax.custom_batching.custom_vmap


3.4 KDA 模块(attention_kda.py

移除拒绝逻辑

python
# attention_kda.py:314-315 删除以下两行
if decoder_segment_ids is not None:
  raise NotImplementedError("KDA does not yet support packed sequences.")

同步 padding 对齐

替换为对齐 padding 的处理(参考已有的 T % chunk_size != 0 分支,行 322-330)。当 KDA 因 chunk_size=64 对 T 做 padding 时,segment_ids 也需要等长 padding,pad 值 0(即 padding token,内核会忽略):

python
B, T, _ = hidden_states.shape

chunk_size = 64
if T % chunk_size != 0:
  pad_len = chunk_size - (T % chunk_size)
  hidden_states = jnp.pad(hidden_states, ((0, 0), (0, pad_len), (0, 0)))
  if decoder_segment_ids is not None:
    # pad 到对齐长度;pad_value=0 表示 padding token(与 1-based segment id 约定相容)
    decoder_segment_ids = jnp.pad(decoder_segment_ids, ((0, 0), (0, pad_len)))
  T = hidden_states.shape[1]
  needs_unpad = True
else:
  needs_unpad = False

接入 shard_map

segment_idsq/k/v 共享 (activation_batch, activation_norm_length) 这两条 mesh 轴(不需要 head/kv 维),新建一个 pspec:

python
seg_pspec = self._logical_to_mesh_axes(("activation_batch", "activation_norm_length"))

shard_map 把 batch 沿 activation_batch 切成 [B_local, T, ...],然后调 chunk_kda(仍是 2D 入口);vmap 在 chunk_kda 内部把 B_local 维剥成单行喂给 1D kernel。shard_mapin_specs 不能容纳 None,所以仍需把"带 segment / 不带 segment"拆成两个 shard_map(推荐方案 A,详见下方备注):

python
# 公共参数
qkv_pspec = self._logical_to_mesh_axes(self.qkv_axis_names)
beta_pspec = self._logical_to_mesh_axes(self.beta_axis_names)
a_log_pspec = self._logical_to_mesh_axes(("activation_heads",))
dt_bias_2d_pspec = self._logical_to_mesh_axes(("activation_heads", "activation_kv"))
seg_pspec = self._logical_to_mesh_axes(("activation_batch", "activation_norm_length"))

dt_bias_2d = self.dt_bias.value.reshape(self.num_key_heads, self.key_head_dim)

if decoder_segment_ids is None:
  # 旧路径:不传 segment_ids
  @functools.partial(
      jax.shard_map,
      mesh=self.mesh,
      in_specs=(qkv_pspec, qkv_pspec, qkv_pspec, qkv_pspec,
                beta_pspec, a_log_pspec, dt_bias_2d_pspec),
      out_specs=qkv_pspec,
      check_vma=False,
  )
  def _kda_no_seg(q, k, v, g, beta, A_log, dt_bias_2d):
    o, _ = chunk_kda(                          # 2D 入口,内部 vmap → 1D kernel
        q=q, k=k, v=v, g=g, beta=beta,
        A_log=A_log, dt_bias=dt_bias_2d.reshape(-1),
        scale=scale, chunk_size=chunk_size,
        initial_state=None, output_final_state=False,
        use_qk_l2norm_in_kernel=False, use_gate_in_kernel=True,
        safe_gate=safe_gate, lower_bound=lower_bound,
    )
    return o

  o = _kda_no_seg(q, k, v, g, beta, self.A_log.value, dt_bias_2d)
else:
  # varlen 路径:把 segment_ids 一并喂进 shard_map
  @functools.partial(
      jax.shard_map,
      mesh=self.mesh,
      in_specs=(qkv_pspec, qkv_pspec, qkv_pspec, qkv_pspec,
                beta_pspec, a_log_pspec, dt_bias_2d_pspec, seg_pspec),
      out_specs=qkv_pspec,
      check_vma=False,
  )
  def _kda_with_seg(q, k, v, g, beta, A_log, dt_bias_2d, seg):
    o, _ = chunk_kda(
        q=q, k=k, v=v, g=g, beta=beta,
        A_log=A_log, dt_bias=dt_bias_2d.reshape(-1),
        scale=scale, chunk_size=chunk_size,
        initial_state=None, output_final_state=False,
        use_qk_l2norm_in_kernel=False, use_gate_in_kernel=True,
        safe_gate=safe_gate, lower_bound=lower_bound,
        segment_ids=seg,                       # ← [B_local, T],wrapper 内 vmap 拆成 [T]
    )
    return o

  o = _kda_with_seg(q, k, v, g, beta, self.A_log.value, dt_bias_2d, decoder_segment_ids)

shard_map None 入参的处理:与上一版相同——shard_mapin_specs 不能直接容纳 None,所以拆成两个 shard_map(方案 A);segment_ids=None 时连 stub 入参都没有,与今天 100% 等价。

嵌套语义:从外到内是 shard_map → chunk_kda(vmap → [None] 扩 B=1 → pallas_chunk_kda → tops_chunk_kda)。shard_map 提供 mesh-level 分片,vmap+[None] 提供 batch-level 隔离,kernel 看到 [1, T, H, D] 走单样本路径。三层职责正交,易于推理。


3.5 Ling3 解码层(ling3.py

ling3.py:262-287 的整个 attention 分发块替换如下,关键差异是 KDA 分支补上 decoder_segment_ids kwarg、并去掉 "KDA does not support packed sequences" 注释:

python
if isinstance(self.attention, attention_mla.MLA):
  attention_output, kv_cache = self.attention(
      hidden_states,
      hidden_states,
      decoder_positions,
      decoder_segment_ids=decoder_segment_ids,
      deterministic=deterministic,
      model_mode=model_mode,
      out_sharding=self.out_sharding,
      previous_chunk=previous_chunk,
      page_state=page_state,
      slot=slot,
      kv_cache=kv_cache,
      attention_metadata=attention_metadata,
  )
else:
  # KDA path — same call shape as Ling2's GLA branch.
  attention_output, _ = self.attention(
      hidden_states,
      decoder_positions,
      deterministic,
      model_mode,
      layer_idx=global_layer_idx,
      decoder_segment_ids=decoder_segment_ids,   # ← 新增;2D [B, T] 透传
  )
  kv_cache = None

Ling3ScannableBlock.__call__ 已经把 decoder_segment_ids 作为位置参数传给每层 Ling3MoEDecoderLayerling3.py:438-447),无需改动。


3.6 RoPE 行为说明

本 RFC 沿用 MaxText 当前的耦合设计:开启 reset_attention_mask=true 后,MegatronSplitInputsTargets (input_pipeline_utils.py:1047-1063) 同时让 inputs_position 按 doc 重置(每个 doc 内部 position 从 0 开始)。这是标准 varlen 行为(FlashAttention varlen + RoPE 通常这么用)。

KDA 自身不消费 decoder_positionsattention_kda.py:308del decoder_positions),所以 RoPE 行为变化只影响 MLA 分支,与本 RFC 的 KDA 改动正交。


备选方案

方案描述否决理由
跳过 vmap,直接 2D 调用kernel 已接受 [B, T, H, D],可直接调用让 kernel 内部按 batch 维并行(B>1 路径)行间隔离靠"约定 + kernel 内 batch 循环正确",比 vmap+B=1 结构性保证弱;隔离正确性需 kernel 单独证明。已选 vmap+B=1 扩维方案
Gate 注入法不改内核,在 KDA 模块里把 doc 首 token 的门 g 置为 -∞,让 S' = S * exp(g) 在边界归零内核团队已决定在内核内做处理,此方案作废
分段调用法按 doc 拆开 packed 序列,对每段独立调用 chunk_kdashape ragged,难批化;性能差;实现复杂
维持 reset_attention_mask=false 不上 varlen不做改动与 Ling3 训练对齐 Megatron 参考的目标冲突;KDA 跨文档泄漏长期存在

影响范围

模块影响
Ling3 训练(pretrain_ling3_tiny.sh启用 RESET_ATTENTION_MASK=true 后,KDA 与 MLA 都按 doc 隔离;初期 loss 相对 baseline 会有变化(更接近 Megatron 参考)
Ling2 训练零影响(Ling2 用 GLA,不走 KDA 路径)
KDA 单元测试反向测试需要替换为正向测试
Ling3 解码层测试新增 segment_ids 透传路径覆盖
推理路径当前 KDA 仍 raise NotImplementedError("KDA autoregressive mode not yet implemented.")attention_kda.py:317-318),不在本 RFC 范围
Checkpoint不影响参数 shape,已有 ckpt 直接续训可用

实施计划

Phase内容依赖
P0与内核同学最终 confirm §3.1.1 的 [B, T, H, D] kernel 接口 + segment_ids: [B, T] 形状;dump jax.vmap(pallas_chunk_kda)(内部 B=1 扩维) 的 HLO,与"kernel 直接吃 B>1"的 HLO 对比验证 Pallas 默认 vmap rule + B=1 路径等价、无 sequential fallback;定稿 API contract
P1wrapper PR:kernels/kda/__init__.pychunk_kda 的 vmap 包装 + segment_ids 参数;pallas.py 切到 1D 接口P0
P2KDA 模块 PR:移除 NotImplementedError,加 padding 与 shard_map 透传;ling3.py 去掉 dropP1;内核侧 1D tops.ops.kda.chunk_kda 接受 segment_ids
P3测试 PR:反转 test_packed_sequences_not_supported,新增正向测试 + 行间隔离测试 + ling3 集成测试P2
P4启用:scripts/pretrain_ling3_tiny.sh 切到 RESET_ATTENTION_MASK="true",跑 STEPS=20 烟测 + 100 步 baseline 对比 + vmap 性能 benchmarkP3

P1 可在内核 PR 未合入前先以 stub 形式 land(向后兼容,不改变行为)。P2 需要内核侧能力就绪。


测试方案

单元测试(tests/unit/kda_attention_test.py

python
def test_packed_sequences_supported(self, mesh):
  """传入 segment_ids 不再 raise,输出 shape 正确。"""

def test_segment_ids_padding_alignment(self, mesh):
  """T 不能被 chunk_size=64 整除时,segment_ids 与 hidden_states 同步 pad。"""

def test_packed_equivalence_with_per_doc(self, mesh):
  """数值等价性(核心):
     pack(doc_a, doc_b) + segment_ids → 输出
     concat(run(doc_a), run(doc_b)) → 输出
     两者在 fp32 下应近似 bit-exact(容忍 1e-5)。
     需要内核侧 doc 边界 reset 实现正确。"""

def test_segment_ids_none_fallback(self, mesh):
  """segment_ids=None 时与现有行为完全一致(回归保护)。"""

def test_row_independence_under_vmap(self, mesh):
  """行间隔离硬验证:构造 batch=[row_a, row_b],只改 row_b 的 segment_ids,
     断言 row_a 的输出 bit-exact 不变。证明 vmap 提供的结构性隔离生效。"""

集成测试(tests/unit/ling3_decoder_test.py 新增)

python
def test_kda_branch_receives_segment_ids(self, mesh):
  """构造一个 KDA 位置的 layer,断言 decoder_segment_ids 不被 drop(mock attention 检查入参)。"""

def test_ling3_full_layer_with_segment_ids(self, mesh):
  """整个 Ling3GenericLayer 端到端跑 packed 输入,shape 正确,无 NaN。"""

端到端烟测

步骤验证项
STEPS=20 RESET_ATTENTION_MASK=true bash scripts/pretrain_ling3_tiny.sh不抛异常;step time 合理(KDA 内核加 segment 处理后预计 < 5% 退化)
data["inputs_segmentation"][0] 打印检查形态为 [1,1,...,2,2,...,3,...] 而非全 1
reset=true vs reset=false 各跑 100 步loss 曲线对比,初期 loss 应略升后收敛更好

风险

风险影响缓解
API contract 与内核实现不一致(特别是 §3.1.1 的 [B, T, H, D] shape、segment_ids: [B, T] 形状与 1-based / 0=padding 约定)数值错误,可能初期 loss 正常但训练长期 divergeP0 阶段双方文档化 contract;P3 加数值等价性测试,与 per-doc 独立运行做 bit-level 对比
Pallas 默认 vmap rule + B=1 扩维不适用于 KDA kernel(例如 kernel 内部依赖 program_id 顺序的副作用,或 B=1 与 B>1 路径在数值上不一致)vmap + [None] 后 HLO 与"kernel 直接吃 B>1"不等价,可能数值错误或性能退化P0 让内核同学 dump 一次 jax.vmap(pallas_chunk_kda)(内部 B=1 扩维) 的 HLO,与"kernel 直接吃 B>1"的 HLO 对比;同时 bit-level 校验 B=1 与 B>1 路径数值一致;若不等价则切到 jax.custom_batching.custom_vmap 显式定义
KDA 内核加 segment 处理后性能退化step time 上升烟测阶段 benchmark;如果退化 >10%,与内核同学 review 是否可优化
shard_mapsegment_ids 的 pspec 与 q/k/v 不一致导致编译错误或正确性问题训练崩溃或 silent 错误单元测试覆盖;首次启用前用小 mesh(如 2x2)跑 sanity check
Checkpoint 续训行为变化现有 ling3-tiny-conv ckpt(reset=false 训出)切到 reset=true 后初期 loss 跳变在切换前在 RFC/PR 描述中明确告知;先跑 100 步 baseline 对比,确认是预期跳变而非 bug
MTP 层(layer_idx >= num_decoder_layers)走 MLA 分支,需要确认 decoder_segment_ids 在 MTP 路径上也正确透传MTP loss 异常集成测试覆盖 MTP 路径;本 RFC 范围内不改 MTP 代码

向后兼容性

保证机制
reset_attention_mask=false 训练完全等价KDA 模块在 decoder_segment_ids is None 时走旧路径;wrapper segment_ids=None 不传给内核
Ling2 零影响Ling2 走 GLA(BailingMoeV2LinearAttention),不经过本 RFC 改动
现有 ling3 ckpt 直接可加载不改变任何参数 shape 或命名
现有内核 wrapper 调用方零影响segment_ids 是带默认值的 kwarg

开放问题

  1. Pallas 默认 vmap rule + B=1 扩维等价性(§3.3):P0 阶段需要内核同学 dump 一次 jax.vmap(pallas_chunk_kda)(内部 q[None] 扩 B=1) 的 HLO,与"kernel 直接吃 B>1"的 HLO 逐 op 对比,确认 grid 扩展 + BlockSpec index_map 改写后语义等价、无 sequential fallback,且 B=1 路径与 B>1 路径数值一致。若发现差异(例如 kernel 内部依赖 program_id 顺序的副作用,或 B=1/B>1 走不同代码路径),则切到 jax.custom_batching.custom_vmap 显式定义 vmap rule。P0 必须确认
  2. shard_map None 入参处理(§3.4):选 A(拆两分支)还是 B(构造全 1 占位)。倾向 A,PR 中可一并讨论。
  3. 数值等价性测试的 tolerance:fp32 下 bit-exact 还是 1e-5 容忍?取决于内核内部是否走完全相同的浮点路径。待内核 PR 落地后定。
  4. 未来推理路径:KDA 推理(autoregressive prefill / decode)目前未实现。本 RFC 不涉及,但启用 varlen 后推理路径如何处理 segment_ids 需要后续 RFC。