Skip to content

Varlen Chunk-KDA: Padding 对结果的影响分析

背景

在 varlen(变长序列)模式下,每个序列的长度 L 可能不是 chunk_size(记为 BT,通常 64)的整数倍。例如 L=100, BT=64 时,需要 2 个 chunk:第一个满 64 tokens,第二个只有 36 个真实 token。

FLA 库(Triton 实现,v0.4.2)和本地 Pallas 实现(tops/ops/kda/)采用不同的策略处理这种情况。本分析验证两种策略下,padding 是否会影响真实 token 的计算结果。


一、FLA 库(Triton 实现)

安装路径.venv/lib/python3.12/site-packages/fla/ops/kda/

核心思路:不对输入张量做任何物理填充,每个 chunk 内只处理真实存在的 token。

1.0 公共 API

入口函数chunk_kda() @ fla/ops/kda/chunk.py:144

python
def chunk_kda(
    q, k, v, g, beta,              # [B, T, H, D]
    scale=None,
    initial_state=None,            # [N, H, K, V] 用于 varlen
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
    use_gate_in_kernel=False,
    cu_seqlens=None,               # [N+1], varlen 模式入口
    cu_seqlens_cpu=None,
    safe_gate=False,
    lower_bound=None,
    disable_recompute=False,
    return_intermediate_states=False,
    cp_context=None,
    transpose_state_layout=False,
    **kwargs,
):

内部使用 torch.autograd.Function

  • ChunkKDAFunction.forward() @ chunk.py:18 → 调用 chunk_kda_fwd()
  • ChunkKDAFunction.backward() @ chunk.py:101 → 调用 chunk_kda_bwd()

Varlen 约定(chunk.py:285-295):当 cu_seqlens is not None 时,要求 B=1T = sum(seq_lens)(所有序列平坦化拼接),initial_state.shape[0] == len(cu_seqlens)-1

1.1 Forward 调用链

text
chunk_kda()                                                    chunk.py:144
  └── ChunkKDAFunction.forward()                               chunk.py:18
        ├── (可选) l2norm_fwd(q) / l2norm_fwd(k)               modules/l2norm.py
        ├── prepare_chunk_indices(cu_seqlens, BT) [纯 PyTorch] utils/index.py:112
        └── chunk_kda_fwd()                                    chunk_fwd.py:16

              ├── [Phase A] 门控预处理
              │   ├── use_gate_in_kernel=True:
              │   │   kda_gate_chunk_cumsum()                  gate.py:415
              │   │     └── Triton: kda_gate_chunk_cumsum_vector_kernel  gate.py:362
              │   └── use_gate_in_kernel=False:
              │       chunk_local_cumsum()                      utils/cumsum.py:429
              │         └── Triton: chunk_local_cumsum_vector_kernel      cumsum.py:86

              ├── [Phase B] Intra-Chunk
              │   chunk_kda_fwd_intra()                         chunk_intra.py:951
              │     ├── safe_gate=True:
              │     │   └── Triton: chunk_kda_fwd_kernel_intra_sub_chunk  chunk_intra.py:817
              │     ├── safe_gate=False:
              │     │   └── Triton: chunk_kda_fwd_kernel_intra_token_parallel  chunk_intra_token_parallel.py:25
              │     ├── Triton: chunk_kda_fwd_kernel_inter_solve_fused       chunk_intra.py:37
              │     └── recompute_w_u_fwd()                     wy_fast.py:210
              │           └── Triton: recompute_w_u_fwd_kda_kernel           wy_fast.py:32
              │     返回: w, u, qg, kg, Aqk, Akk

              ├── [Phase C] 跨 Chunk 状态传播
              │   chunk_gated_delta_rule_fwd_h()                common/chunk_delta_h.py:655
              │     └── Triton: chunk_gated_delta_rule_fwd_kernel_h_blockdim64  chunk_delta_h.py:35
              │     返回: h, v_new, final_state

              └── [Phase D] 输出计算
                  chunk_gla_fwd_o_gk()                           gla/chunk.py:863
                    └── Triton: chunk_gla_fwd_kernel_o           gla/chunk.py:308
                  返回: o

1.2 Backward 调用链

text
chunk_kda_bwd()                                                chunk_bwd.py:410

  ├── [S0] 重计算前向中间结果
  │   ├── kda_gate_chunk_cumsum() (如需要)                      gate.py:415
  │   ├── recompute_w_u_fwd()                                   wy_fast.py:210
  │   └── chunk_gated_delta_rule_fwd_h()                        common/chunk_delta_h.py:655

  ├── [S1] dAqk + dv
  │   chunk_kda_bwd_dAv()                                       chunk_bwd.py:294
  │     └── Triton: chunk_kda_bwd_kernel_dAv                    chunk_bwd.py:42

  ├── [S2] 反向状态传播
  │   chunk_gated_delta_rule_bwd_dhu()                          common/chunk_delta_h.py:715
  │     └── Triton: chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64  chunk_delta_h.py:345

  ├── [S3] WY 融合反向
  │   chunk_kda_bwd_wy_dqkg_fused()                             chunk_bwd.py:345
  │     └── Triton: chunk_kda_bwd_kernel_wy_dqkg_fused          chunk_bwd.py:125

  ├── [S4] Intra-chunk 反向
  │   chunk_kda_bwd_intra()                                     chunk_intra.py:1030
  │     └── Triton: chunk_kda_bwd_kernel_intra                  chunk_intra.py:367

  ├── [S5] 反向累积和 + 门控反向
  │   ├── chunk_local_cumsum(reverse=True)                      utils/cumsum.py:429
  │   └── kda_gate_bwd() (如需要)                               gate.py:240

  返回: dq, dk, dv, db, dg, dh0, dA, dbias

1.3 Varlen 索引映射

prepare_chunk_indices() @ fla/ops/utils/index.py:112

python
@tensor_cache
def prepare_chunk_indices(cu_seqlens, chunk_size):
    lens = torch.diff(cu_seqlens)                # 各序列长度, 如 [100, 50]
    n_chunks = triton.cdiv(lens, chunk_size)     # 如 [2, 1]
    indices = torch.cat([torch.arange(n) for n in n_chunks])
    # indices = [0, 1, 0]  (seq0: chunk0, chunk1; seq1: chunk0)
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1)
    # → [[0, 0], [0, 1], [1, 0]]  = (seq_id, local_chunk_id)

prepare_chunk_offsets() @ fla/ops/utils/index.py:125 计算跨序列的累积 chunk 数,用于索引 [NT_total, H, K, V] 张量。

1.4 Kernel 内的边界检查

每个 Triton kernel 通过 IS_VARLEN 启发式开关(基于 cu_seqlens is not None)获取当前序列边界:

python
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})

Kernel 内部:

python
if IS_VARLEN:
    i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32)
    bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32)
    T = eos - bos        # 动态覆盖编译时常量 T
else:
    bos, eos = i_b * T, i_b * T + T

所有边界检查汇总:

检查方式位置效果
i_t * BT >= T → returninter_solve_fused 开头整个 chunk 超出序列则跳过
i_ti >= T → returnintra_sub_chunk 开头子块起始越界则跳过
if i_tc1 < T: / i_tc2 < T: / i_tc3 < T:inter_solve_fused K 循环条件加载越界子块
m_tc1 = (i_tc1 + o_i) < Tinter_solve_fused子块内逐元素有效性掩码
m_c = o_c < Tintra_sub_chunk子块内逐元素有效性掩码
for i in range(2, min(BC, T - i_tc0))前向替换循环按有效行数截断
min(BC//2, T - i_ti - 1)g_ref 加载门控参考点不越界
last_idx = min((i_t+1)*BT, T) - 1状态传播chunk 最后有效 token
boundary_check=(0, 1)block pointer load/store阻止真实越界访存;但 load 未指定 padding_option 时,越界 lane 的返回值是 undefined,不是 0

1.4.1 Triton tl.load 越界语义

需要区分两类写法:

python
# 标量/指针 + mask: 越界 lane 明确填 other
tl.load(ptr + offsets, mask=mask, other=0.0)

# block pointer + boundary_check: 越界 lane 不访存,但默认返回 undefined
tl.load(block_ptr, boundary_check=(0, 1))

# block pointer 要明确补零时,需要显式写 padding_option
tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

因此,boundary_check 只能保证不发生非法内存访问;它本身不保证越界值是 0。是否影响真实结果,取决于这些 undefined lane 是否会进入真实 token / 真实 feature 的 reduction、dot 或递推。

1.4.2 各 API / kernel 的越界 load 影响

API / 阶段代表 kernel越界来源当前处理mask 后的值对结果的影响
l2norm_fwd() 大 D 路径l2norm_fwd_kernel1BD > D 的 feature tailmask=cols < D, other=0.0无效 feature lane 读成 0.0;store mask 不写回无风险
l2norm_bwd() 大 D 路径l2norm_bwd_kernel1BD > D 的 feature tailmask=cols < D, other=0.0y/dy 无效 feature lane 读成 0.0;store mask 不写回无风险
l2norm_fwd() 小 D 路径l2norm_fwd_kernelBD > D 或最后一个 BT block 超过 Tboundary_check,未指定 padding_optionload 越界 lane 是 undefined;store 越界 lane 不写回有风险:影响 fwd 的 rstd / y
l2norm_bwd() 小 D 路径l2norm_bwd_kernelBD > D 或最后一个 BT block 超过 Tboundary_check,未指定 padding_optionload 越界 lane 是 undefined;store 越界 lane 不写回有风险:影响 bwd 的 dx
kda_gate_chunk_cumsum()kda_gate_chunk_cumsum_vector_kernelT tail、S/K tailboundary_check;cumsum 沿 T 维;store 带 boundary_checkload 越界 lane 是 undefined;没有显式改值;store 越界 lane 不写回无风险,仅限 forward cumsum
chunk_local_cumsum()chunk_local_cumsum_vector_kernelT tail、S/K tail与 gate cumsum 类似load 越界 lane 是 undefined;若使用 mask/other 路径则无效 lane 为 0.0;store 越界 lane 不写回无风险,仅限 forward cumsum
chunk_kda_fwd_intra() token-parallel 路径chunk_kda_fwd_kernel_intra_token_parallelT tail、K taili_t >= T → return;循环 j < min(T, ...);K tail 用 m_k 在部分表达式中置 0越界 token 直接 return 或不进入循环;tl.where(m_k, ..., 0.0) 把无效 K lane 写成 0.0有风险:K tail 未完全清零时影响 fwd 的 intra A / o
chunk_kda_fwd_intra() inter solve 路径chunk_kda_fwd_kernel_inter_solve_fused子块 T tail、K tail、Akk/Aqk tail子块起点 if i_tc* < T;行 mask m_tc*;store 带 boundary_check不存在的子块跳过;tl.where(m_tc*[:, None], ..., 0) 把无效 token 行的 gate factor 写成 0tl.load(..., mask=m_k, other=0) 把 K tail 的 reference gate 写成 0;store 越界 lane 不写回有风险:K tail 未补零时影响 fwd 的 Aqk / Akk
recompute_w_u_fwd()recompute_w_u_fwd_kda_kernelT tail、K tail、V tailT tail 通过 last_idxm_t、store boundary 隔离;K/V tail 多为 block-pointer loadtl.where(m_t[:, None], ..., 0) 把无效 token 行的 kg 写成 0tl.load(..., mask=m_k, other=0.) 把 K tail 的 last gate 写成 0.0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回有风险:K tail 未补零时影响 fwd 的 w / kg / qg
chunk_gated_delta_rule_fwd_h()chunk_gated_delta_rule_fwd_kernel_h_blockdim64最后 chunk T tail、K/V block taillast_idx 取最后真实 token;T tail 在 gate 衰减中用 m_t 置 0;store boundarytl.where(m_t, exp(...), 0) 把无效 token 行的衰减乘子写成 0,从而无效行 b_v 变为 0tl.load(..., mask=o_k < K, other=0.) 把 K tail 的 last gate 写成 0.0;store 越界 lane 不写回有风险:K tail 未补零时影响 fwd 的 state h
chunk_gla_fwd_o_gk() 输出chunk_gla_fwd_kernel_oT tail、K tail、V tail、A tailstore boundary;A 只用因果下三角 m_stl.where(m_s, b_A, 0.) 把非因果的 A 上三角写成 0.0;block-pointer tail 是 undefined;store 越界 lane 不写回有风险:K tail 未补零时影响 fwd 的 o
chunk_kda_bwd_dAv()chunk_kda_bwd_kernel_dAvT tail、V tail、A tailm_t/m_A 清掉无效 A;store boundarytl.where(m_A, b_A, 0) 把无效 A entry 写成 0tl.where(o_t[:, None] >= o_t, ..., 0.) 把非因果 dA entry 写成 0.0;store 越界 lane 不写回有风险:V tail 未补零时影响 bwd 的 dA
chunk_kda_bwd_wy_dqkg_fused()chunk_kda_bwd_kernel_wy_dqkg_fusedT tail、K tail、V tailm_t/m_lastlast_idx 和 store boundary;K/V tail多为 block-pointer loadm_t 用于识别有效 token;m_last 只保留最后真实 token;tl.load(..., mask=m_k, other=0) 把 K tail 的 reference gate 写成 0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回有风险:影响 bwd 的 dq / dk / dv / dg / db / dA
prepare_wy_repr_bwd()prepare_wy_repr_bwd_kda_kernelT tail、K tail、V tailstore boundary;部分标量 tail 用 masktl.load(..., mask=m_k, other=0.) 的 K tail 标量读成 0.0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回有风险:影响 bwd 的中间梯度 dw / du
chunk_kda_bwd_intra()chunk_kda_bwd_kernel_intraT tail、K tail、A tail有子块/循环边界和部分 mask=..., other=0mask=..., other=0 的无效 lane 读成 0tl.where(..., 0.) 的无效项写成 0.0;store 越界 lane 不写回有风险:K tail 未补零时影响 bwd 的 intra dA / dq / dk
chunk_local_cumsum(reverse=True)chunk_local_cumsum_vector_kernelT tail、S/K tailreverse cumsum;store boundary若只有 boundary_check,load tail 是 undefined;若使用 mask/other,无效 lane 是 0.0;store 越界 lane 不写回有风险:影响 reverse cumsum 的真实 token 输出
kda_gate_bwd()kda_gate_bwd_kernelT tail、K tailblock-pointer boundary_check,部分 bias load boundaryblock-pointer tail 是 undefined;store 越界 lane 不写回;dA 没有按 T/D tail mask 后再归约有风险:影响 bwd 的 dg / db

常见 mask 改写规则:

写法无效 lane 的值
tl.load(ptr, mask=m, other=0) / other=0.0读成 0 / 0.0
tl.where(m, x, 0) / tl.where(m, x, 0.0)写成 0 / 0.0
tl.where(m_t, exp(...), 0) 作为乘子无效 token 的乘子为 0,对应贡献被清零
因果 mask,如 tl.where(m_s, A, 0.)tl.where(m_A, A, 0)非因果或无效矩阵 entry 变成 0 / 0.0
tl.store(..., mask=m)boundary_check store无效 lane 不写回,不是写成 0
if ...: return / if i_tc < T: / range(..., min(...))对应 chunk、子块或循环迭代不执行,没有产生新值

关键修正:FLA 的 varlen 主路径确实没有物理 padding token,但这不等价于所有 block-pointer 越界 load 都自动补零。T 维 tail 多数被 returnmin(T, ...)m_t/m_tc、三角 mask 和 store boundary 隔离,所以不会回写真实 token;但 K/V/D 这类被 tl.dottl.sum 归约的尾部 lane,如果 block size 大于真实维度且没有显式 other=0 / padding_option="zero" / tl.where(mask, x, 0),就可能污染真实结果。

对当前 KDA 调用链,最明确的问题是可选的 use_qk_l2norm_in_kernel=True:当 head dim Kl2norm_fwd_kernel / l2norm_bwd_kernelK 不是 BD 的整数覆盖边界(例如非 2 的幂)时,undefined feature tail 会直接进入 norm 的 reduction。修复方式是给 block-pointer load 加 padding_option="zero",或改回显式 mask/other=0.0

1.5 具体示例

seq_lens = [100, 50], BT = 64 为例:

python
# 变量初始化
cu_seqlens = torch.tensor([0, 100, 150])  # 2 个序列, T=150

# prepare_chunk_indices 输出:
# chunk_indices = [[0, 0],   ← seq0 的 chunk0 (64 tokens, 全真实)
#                  [0, 1],   ← seq0 的 chunk1 (36 tokens, 无填充)
#                  [1, 0]]   ← seq1 的 chunk0 (50 tokens, 无填充)

# NT = 3 (总 chunk 数), 而如果定长则 NT = ceil(150/64) = 3
# 但 varlen 下每个 chunk 的 T 值不同:
#   seq0-chunk0: T=100, bos=0,   eos=100
#   seq0-chunk1: T=100, bos=0,   eos=100, 仅处理 positions 64-99
#   seq1-chunk0: T=50,  bos=100, eos=150, 仅处理 positions 100-149

Kernel 内部处理(以 chunk_kda_fwd_kernel_inter_solve_fused 为例 @ chunk_intra.py:37):

python
# chunk0 of seq0 (i_t = 0): i_t * BT = 0, T = 100, 0 < 100 ✓
#   处理 4 个子块 (BC=16): tc0=0, tc1=16, tc2=32, tc3=48 — 全部有效

# chunk1 of seq0 (i_t = 1): i_t * BT = 64, T = 100, 64 < 100 ✓
#   子块 tc0=64 (valid), tc1=80 (valid), tc2=96 (valid 但最后4个token越界)
#   tc3=112 ≥ 100 → 跳过整个 tc3
#   在 tc2 中: m_tc2 = (96+o_i) < 100 → 只有 o_i ∈ {0,1,2,3} 有效
#   前向替换: for i in range(2, min(16, 100-96)) = range(2, 4) → 仅处理行 2,3

# chunk0 of seq1 (grid_idx = 2): chunk_indices[2] = [1, 0]
#   解码后 i_n=1, i_t=0,因此使用 seq 内 local chunk 坐标
#   i_t * BT = 0, T = 50, bos=100, eos=150
#   子块 tc0=0, tc1=16, tc2=32, tc3=48 — 全部 < 50,全部会进入条件分支
#   在 tc3 中: m_tc3 = (48+o_i) < 50 → 只有 o_i ∈ {0,1} 有效
#   前向替换: for i in range(2, min(16, 50-48)) = range(2, 2) → 空循环

1.6 结论

FLA 库完全不引入物理 padding 数据。最后一个 chunk 只包含真实 token。 Varlen 的 T 维尾部通常通过 kernel 内的 chunk 边界、行 mask、三角 mask 和 store boundary 隔离,因此不会写回真实 token。

但要注意:tl.load(block_ptr, boundary_check=...) 默认不是补零。只要越界 lane 进入 tl.sum / tl.dot 这类跨 feature 或跨 value 维度的归约,就可能影响真实结果。当前文档中的 FLA 正确性结论应理解为:T 维 varlen tail 在主 KDA 路径中被隔离;K/V/D 维 tail 还需要维度与 block size 对齐,或显式补零/掩码。


二、Pallas 实现(TPU)

实现路径tops/ops/kda/tops/ops/common/

核心思路:通过 _align_seqs 将每个序列填充到 BT 整数倍,利用 padding token 的零值特性使其在数学上透明(L 矩阵块对角解耦)。

2.0 公共 API

入口函数chunk_kda() @ tops/ops/kda/chunk.py:49

python
@jax.custom_vjp
def chunk_kda(
    q, k, v, g, beta,              # [B, T, H, D]
    A_log=None, dt_bias=None,
    scale=None,
    initial_state=None,            # [N, H, K, V] 用于 varlen
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
    use_gate_in_kernel=False,
    segment_ids=None,              # [B, T] 或 [1, T], varlen 模式入口
    safe_gate=False,
    lower_bound=None,
    disable_recompute=False,
    cp_context=None,
    transpose_state_layout=False,
    chunk_size=64,
    **kwargs,
):

自定义 VJP 绑定 @ chunk.py:428

python
chunk_kda.defvjp(_chunk_kda_fwd_custom, _chunk_kda_bwd_custom)

Varlen 逻辑 chunk.py:88:当 segment_ids is not None 时,调用 segment_ids_to_seqlens() 提取 cu_seqlens;也可以通过 kwargs["cu_seqlens"] 直接传入。下面 Pallas kernel 分析默认讨论 varlen 路径,即 cu_seqlens is not None,非 varlen 只作为共享 wrapper 的背景,不单独展开。

2.1 Forward 调用链(默认 varlen)

text
chunk_kda()                                                   chunk.py:49
  └── custom_vjp:
        _chunk_kda_fwd_custom()                               chunk.py:129

          │ [Varlen 预处理]
          │ ├── segment_ids → cu_seqlens (如需要)             utils.py:193
          │ ├── L2 norm on q/k (如需要)
          │ ├── _align_seqs() → 序列对齐到 BT 倍数            chunk_fwd.py:566
          │ ├── gate padding 修复: g[padding] = -1e4           chunk.py:176-185
          │ └── compute_padded_cu_seqlens()                    tops/utils.py:130

          └── chunk_kda_fwd(_skip_align=True, cu_seqlens=aligned_cu)
              │                                              chunk_fwd.py:623

              ├── [Stage 1+2] kda_fwd_intra_fused()           chunk_intra_fwd_fused.py:586
              │     │
              │     │ bf16 varlen 路径:
              │     └── kda_fwd_intra_fused_varlen()           chunk_intra_fwd_fused.py:438
              │           └── Pallas #1: _fused_gate_intra_kernel  chunk_intra_fwd_fused.py:37
              │               功能: 门控激活 + cumsum + BC=16 Aqk/L + Neumann 求逆
              │               Grid: (1, H, NC)
              │               BlockSpec: 每块 [1, 1, BT, D]
              │               返回: w, u, qg, kg, Aqk, Akk, g_cumsum

              │     fp32 varlen fallback:
              │     └── kda_gate_chunk_cumsum() → kda_fwd_intra()
              │           分别做 S1 和 S2 (不做 BC=16 融合)

              └── [Stage 3+4, varlen]
                  chunk_kda_fwd_h_o_varlen()                   chunk_fwd.py:359
                    └── Pallas #2: _chunk_kda_fwd_h_o_varlen_kernel  chunk_fwd.py:247
                        功能: 状态传播 + 输出计算 (融合)
                        Grid: (H, NT), PrefetchScalarGridSpec(num_scalar_prefetch=2)
                        Scratch: VMEM (K_PADSIZE, V_ALIGNED)
                        BlockSpec: 每块 [1, 1, BT, D]
                        返回: o [1, T, H, V], final_state [N, H, K, V]

              [Varlen 后处理]
              └── _unalign_output() → 反向 gather 回原始 T      chunk_fwd.py:601

2.2 Backward 调用链(默认 varlen)

text
_chunk_kda_bwd_custom()                                       chunk.py:268

  │ [Varlen 预处理]
  │ ├── _align_seqs(do) → 对齐梯度
  │ ├── 转置所有张量 [B,T,H,X] → [H,B,T,X]
  │ └── 计算 aligned cu_seqlens / chunk_indices

  └── chunk_kda_bwd(cu_seqlens=aligned_cu, chunk_indices=...)  chunk_bwd.py:815

        ├── [S0] 重计算前向中间结果
        │   ├── kda_gate_chunk_cumsum() (如 use_gate_in_kernel)
        │   │     └── chunk_local_cumsum_vector()               cumsum.py:492
        │   │           └── Pallas #8: _chunk_cumsum_kernel_varlen  cumsum.py:314
        │   ├── recompute_w_u_fwd() (CPU reference)
        │   └── chunk_gated_delta_rule_fwd_h()                  chunk_delta_h.py:844
        │         └── Pallas #3: _chunk_gated_delta_rule_fwd_varlen_kernel  chunk_delta_h.py:188
        │              Grid: (H, NT), PrefetchScalarGridSpec(2)
        │              Scratch: VMEM (K_PADSIZE, V_ALIGNED)

        ├── [S1] dAqk + dv
        │   chunk_kda_bwd_dAv_kernel()                          chunk_bwd.py:470
        │     └── Pallas #4: _chunk_kda_bwd_dAv_kernel          chunk_bwd.py:401
        │          Grid: (BH * NT,), 1D 扁平化
        │          计算: dA = do @ v^T * scale, dv = A_masked^T @ do

        ├── [S2] 反向状态传播
        │   chunk_gated_delta_rule_bwd_dhu_kernel()             chunk_bwd.py:657
        │     └── Pallas #5: _chunk_gated_delta_rule_bwd_dhu_kernel  chunk_bwd.py:574
        │          Grid: (H, NT), PrefetchScalarGridSpec(2)
        │          Scratch: VMEM (K, V)
        │          Varlen 动态切片: dht/dh0 按 segment 切片避免加载全 N 维度
        │          反向迭代: i_t = NT - 1 - i_c

        ├── [S3] WY 融合反向
        │   chunk_kda_bwd_wy_dqkg_fused_kernel()                chunk_bwd.py:220
        │     └── Pallas #6: _chunk_kda_bwd_wy_dqkg_fused_kernel  chunk_bwd.py:25
        │          Grid: (BH * NT,), 展开 K/V 循环
        │          输入11个, 输出6个: dq, dk, dv2, dg, db, dA

        ├── [S4] Intra-chunk 反向
        │   chunk_kda_bwd_intra()                               chunk_intra.py:588
        │     └── Pallas #7: kda_intra_chunk_bwd_kernel_subchunk  chunk_intra.py:306
        │          Grid: (H, B, NT), BC=16 子块分解
        │          累加 inter 和 intra 贡献

        ├── [S5] 反向累积和
        │   chunk_local_cumsum_vector(reverse=True)             cumsum.py:492
        │     └── Pallas #8: _chunk_cumsum_kernel_varlen         cumsum.py:314
        │          Hillis-Steele 扫描 + fori_loop over chunks

        └── 门控反向 (如需要)
            kda_gate_bwd()                                      gate.py:95
        返回: dq, dk, dv, dg, db, dh0, dA, dbias

  [Varlen 后处理]
  ├── 转置 [H,B,T,X] → [B,T,H,X]
  ├── L2 norm backward (如需要)
  └── _unalign_output(dq, dk, dv, dg, db)

2.3 Pallas Varlen Kernel 汇总(默认路径)

#Kernel 函数文件:行调用者阶段
1_fused_gate_intra_kernelchunk_intra_fwd_fused.py:37kda_fwd_intra_fused_varlen()FWD S1+S2
2_chunk_kda_fwd_h_o_varlen_kernelchunk_fwd.py:247chunk_kda_fwd_h_o_varlen()FWD S3+S4
3_chunk_gated_delta_rule_fwd_varlen_kernelchunk_delta_h.py:188_chunk_gated_delta_rule_fwd_varlen()BWD S0
4_chunk_kda_bwd_dAv_kernelchunk_bwd.py:401chunk_kda_bwd_dAv_kernel()BWD S1
5_chunk_gated_delta_rule_bwd_dhu_kernelchunk_bwd.py:574chunk_gated_delta_rule_bwd_dhu_kernel()BWD S2
6_chunk_kda_bwd_wy_dqkg_fused_kernelchunk_bwd.py:25chunk_kda_bwd_wy_dqkg_fused_kernel()BWD S3
7kda_intra_chunk_bwd_kernel_subchunkchunk_intra.py:306_kda_intra_chunk_bwd_subchunk_pallas()BWD S4
8_chunk_cumsum_kernel_varlencumsum.py:314_chunk_local_cumsum_pallas()BWD S0/S5

2.4 关键工具函数

函数文件:行用途
_segment_ids_to_packed()chunk_fwd.py:203多 batch segment_ids → 平坦化 B=1 布局
_align_seqs()chunk_fwd.py:566每个序列对齐到 BT 倍数, 零填充
_unalign_output()chunk_fwd.py:601反向 gather, 丢弃 padding 输出
prepare_chunk_indices()tops/utils.py:255返回 [NT, 2] 的 (seq_id, chunk_id)
compute_padded_cu_seqlens()tops/utils.py:130对齐后的 cu_seqlens
segment_ids_to_seqlens()tops/utils.py:1931D segment_ids → cu_seqlens
_build_chunk_map()common/chunk_h.py:13searchsorted 建立 chunk→seq 映射
chunk_local_cumsum_vector()common/cumsum.py:492门控 chunk 内累积和分发器

2.5 _align_seqs — Padding 机制

_align_seqs() @ chunk_fwd.py:566-598

python
def _align_seqs(tensors_4d, tensors_3d, cu_seqlens, align):
    N = cu_seqlens.shape[0] - 1
    T_old = tensors_4d[0].shape[1]

    seg_lens = cu_seqlens[1:] - cu_seqlens[:-1]         # 如 [100, 50]
    padded_lens = ((seg_lens + align - 1) // align) * align  # → [128, 64]
    padded_cu = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32),
                                  jnp.cumsum(padded_lens)])  # → [0, 128, 192]
    T_new = ((T_old + N * (align - 1) + align - 1) // align) * align

    # gather_idx 映射: 真实位置→原位置, padding 位置→T_old (零填充区)
    gather_idx = jnp.full(T_new, T_old, dtype=jnp.int32)
    gather_idx = jax.lax.fori_loop(0, N, _build_gather, gather_idx)

    # 先 pad 零, 再 gather
    def repack_4d(t):
        return jnp.pad(t, ((0,0), (0, T_new-T_old), (0,0), (0,0)))[:, gather_idx]

结果:padding 位置的 q=k=v=beta=0。

2.6 门控中和修复

chunk.py:176-185

python
if use_gate_in_kernel:
    orig_lens = _orig_cu_seqlens[1:] - _orig_cu_seqlens[:-1]
    aligned_starts = cu_seqlens[:-1]
    pos = jnp.arange(T)
    in_range = (pos[None, :] >= aligned_starts[:, None]) & (
        pos[None, :] < (aligned_starts + orig_lens)[:, None]
    )
    valid_mask = in_range.any(axis=0)  # [T], True=真实 token
    g = jnp.where(valid_mask[None, :, None, None], g, -1e4)

_align_seqs 将 gate 填充为 0,但 softplus(0 + dt_bias) 可能非零。将 padding 位置的 g 设为 -1e4 确保 softplus(-1e4 + dt_bias) ≈ 0

2.7 _unalign_output — 输出反对齐

_unalign_output() @ chunk_fwd.py:601-620

python
def _unalign_output(o, orig_cu_seqlens, aligned_cu_seqlens, T_out):
    N = orig_cu_seqlens.shape[0] - 1
    safe_default = aligned_cu_seqlens[-1]  # 尾部零位置,确保读取零而非位置0

    gather_idx = jnp.full(T_out, safe_default, dtype=jnp.int32)
    gather_idx = jax.lax.fori_loop(0, N, _build_gather, gather_idx)
    return o[:, gather_idx]  # [1, T_out, H, V]

关键:safe_default = aligned_cu_seqlens[-1] 指向对齐后所有序列末尾的 padding 区域(保证为零),避免之前 safe_default=0 的 bug(会重复读取位置 0 的输出)。

2.8 具体示例

seq_lens = [100, 50], BT = 64 为例:

python
# === _align_seqs 之后 ===
# T_old = 150
# padded_lens = [128, 64]  (BT=64 对齐)
# padded_cu[-1] = 192  (128 + 64, 有效对齐后总长度)
# T_new = 320  ((_align_seqs 内部申请的安全物理上界))
# padded_cu = [0, 128, 192]
#
# 布局:
#   positions   0..99:   seq0 真实 (原 0..99)
#   positions 100..127:  seq0 padding (零)
#   positions 128..177:  seq1 真实 (原 100..149)
#   positions 178..191:  seq1 padding (零)
#   positions 192..319:  安全尾部 padding (零, 不属于任何真实 seq chunk)
#
# Stage 1+2 fused intra 的张量 chunk 结构 (NC = T_new / BT = 320/64 = 5):
#   chunk 0: positions   0..63   = seq0 真实
#   chunk 1: positions  64..127  = seq0 真实(36) + padding(28)
#   chunk 2: positions 128..191  = seq1 真实(50) + padding(14)
#   chunk 3: positions 192..255  = 安全尾部 padding (纯零)
#   chunk 4: positions 256..319  = 安全尾部 padding (纯零)
#
# Stage 3+4 fused state/output 的 varlen chunk 结构 (NT = len(chunk_indices) = 3):
#   chunk_indices 只覆盖 padded_cu[-1] = 192 以内的 3 个有效对齐 chunk
#   Stage 1+2 产生的 chunk 3/4 是纯零安全尾部,不会被 Stage 3+4 读取

这里要区分两个长度和两个 chunk 计数:T_new_align_seqsgather_idxjnp.pad 申请的安全上界,Stage 1+2 会按 T_new / BT reshape 成完整 chunk;padded_cu[-1] 才是所有序列按 BT 对齐后的有效总长度,Stage 3+4 的 chunk_indices 只覆盖这部分有效 chunk。超过 padded_cu[-1] 的尾部位置由 jnp.pad(..., 0) 提供,是纯零安全尾部,不会被 Stage 3+4 读取,也不会被 _unalign_output() gather 回真实输出。

2.9 Padding 在求解中的数学解耦

Stage 1+2:L 矩阵块对角结构

_fused_gate_intra_kernel @ chunk_intra_fwd_fused.py:37 中的关键推导:

门控累积和:

python
g_f32 = -exp(A_val) * softplus(g_f32 + dt_bias)  # padding: g=-1e4 → g_act≈0
g_cumsum = tril @ g_f32  # padding 行: 平坦继承最后真实 token 的 cumsum

L 矩阵(BC=16 子块分解):

python
# L[i,j] = beta_i * (k_eg_i @ k_eng_j^T)
# k_eg = k * exp2(g_cumsum - g_ref)
# k_eng = k * exp2(g_ref - g_cumsum)

对于 padding 位置:

  • padding 行 ik_i=0beta_i=0L[padding, :] = 0
  • padding 列 jk_j=0k_i @ k_j^T = 0L[:, padding] = 0

L 矩阵呈现块对角结构

text
L = [L_real,  0    ]    ← BT_real × BT_real 有效块
    [0,       0    ]    ← BT_pad × BT_pad 零块

求解 (I+L)x = rhs:

  • padding 行: RHS = [v_beta=0, k_eg_beta=0, I_padding_row]
  • (I + 0)^{-1} @ [0, 0, I_row] = [0, 0, I_row]
  • u[padding]=0, w[padding]=0, A_inv[padding]=one_hot(padding)

真实行的解 (I + L_real)^{-1} @ rhs_real 与 padding 完全无关。

Stage 3+4:Padding 透明传播

_chunk_kda_fwd_h_o_varlen_kernel @ chunk_fwd.py:247 中的关键推导:

v_new 计算(Stage 3 修正值):

python
b_v_new = b_u - b_w @ b_h      # padding: 0 - 0 @ h = 0

输出计算(Stage 4):

python
b_qg = b_q * exp2(g - g_ref)    # padding: 0 * exp2(...) = 0
b_o_inter = scale * b_qg @ (b_h * exp2(g_ref))   # = 0
b_o_intra = tril(A) @ b_v_new                     # = tril(A) @ 0 = 0

状态更新:

python
b_gk_last = gk_ref[BT-1]        # padding chunk: = g_cumsum[last_real]
scratch *= exp2(b_gk_last)      # 按真实衰减率衰减
scratch += b_kg^T @ b_v_new     # padding: 0^T @ 0 = 0, 状态不变

2.10 逐 Token 追踪(chunk1: 36真实 + 28padding)

text
Token   | 64 65 ... 99 | 100 101 ... 127
--------|--------------|-----------------
类型    | 真实         | padding
g_orig  | 实值         | -1e4 (修复后)
g_act   | 实值         | softplus(-1e4) ≈ 0
g_cumsum| 递增         | 平坦 = cumsum[99]
q,k,v   | 实值         | 0 (align 填充)
beta    | 实值         | 0 (align 填充)
--------|--------------|-----------------
L 矩阵  | L[0:36, 0:36] 实值   | L[36:64, :] = 0
        | L[0:36, 36:64] = 0   | L[:, 36:64] = 0
RHS     | v_beta, k_eg_beta 实值 | 0, 0
--------|--------------|-----------------
u       | 实值 (来自求解)        | 0
w       | 实值                   | 0
kg      | 实值                   | 0
qg      | 实值                   | 0
--------|--------------|-----------------
v_new   | u - w@h 实值           | 0 - 0@h = 0
o       | inter + intra 实值     | 0 + 0 = 0
状态    | h *= exp2(cumsum[99])  | h *= exp2(cumsum[99])
更新    | h += kg^T @ v_new 实值 | h += 0 = h

2.11 当前 Pallas Varlen 实现的 padding / 越界处理影响

本表默认按 Pallas varlen 路径审计。当前 Pallas varlen 路径和 FLA 的核心差异是:Pallas 不依赖 kernel 内 boundary_check 给越界 lane 兜底,而是在进入 Pallas kernel 之前把 T/K/V 维度补到 kernel 需要的安全形状。Pallas varlen kernel 中即使设置了 disable_bounds_checks=True,读写对象也是已经 jnp.pad_align_seqs 过的张量,因此不会出现 FLA tl.load(block_ptr, boundary_check=...) 那种 undefined lane 语义。

API / 阶段代表 kernel / 函数padding / 越界来源当前处理mask / pad 后的值对结果的影响
chunk_kda() varlen 入口_segment_ids_to_packed() / segment_ids_to_seqlens()多 batch segment_ids 被展平到 B=1只做 layout pack 和 cu_seqlens 生成,不引入 padding 值没有新增 padding lane无风险
varlen 前向对齐_align_seqs()每个序列长度不是 BT 整数倍jnp.pad(..., 0),再用 gather_idx 把每段放到 BT 对齐位置q/k/v/beta 的 padding token 变成 0g 初始 padding 为 0无风险,但 use_gate_in_kernel=True 还依赖下一行 gate 修复
gate padding 修复_chunk_kda_fwd_custom() / chunk_kda_fwd()_align_seqs() 会把 raw gate padding 成 0对 padding 位置执行 jnp.where(valid_mask, g, -1e4)padding g 变成 -1e4softplus(-1e4 + dt_bias) ≈ 0无风险
use_qk_l2norm_in_kernel=Truetops.cpu.ops.common.l2norm_fwd/bwdK 非 2 的幂或非 block 对齐当前 Pallas 路径在 JAX/CPU reference 上做整张量 norm,不使用 block-pointer load没有 undefined feature tail;padding token 的 q/k 已经是 0无风险
FWD S1+S2 fused intra_fused_gate_intra_kernel / kda_fwd_intra_fused_varlen()T tail 来自 _align_seqs();K/V 维按真实维度进入该 kernel要求 T % BT == 0;直接 reshape 成完整 chunk;因果上三角用零矩阵或 mask 清掉padding token 的 q/k/v/beta=0;padding gate 激活约为 0;非因果 Aqk/Akk entry 为 0无风险
FWD S3+S4 fused state/output_chunk_kda_fwd_h_o_varlen_kernelK tail、V tail、额外安全 T chunk、空序列 final_statewrapper 先把 K/V 补到 128 对齐,T 额外补 BT;输出再裁回真实 K/V/T;empty seq 用 jnp.where 填初始状态或 0K/V/T padding 都是 0;empty final_state 变成 initial_state0无风险
BWD 对齐_chunk_kda_bwd_custom() + _align_seqs(do)do 原始长度不是 BT 对齐backward 开始先把 do 对齐到和 forward residual 一致的 Tpadding do 变成 0无风险
BWD S0 重计算kda_gate_chunk_cumsum() / recompute_w_u_fwd() / chunk_gated_delta_rule_fwd_h()aligned padding token、K/V 对齐 padding使用 forward residual 的对齐张量;h chunk 数不足时 jnp.pad(..., 0)padding token 继续满足 q/k/v/beta/do=0;padding h chunk 为 0无风险
BWD S1 dAqk + dv_chunk_kda_bwd_dAv_kernelV tail、A 上三角、padding tokenblock_V 默认等于 V,若自定义必须整除;A 用 causal mask 清上三角非因果 A/dA entry 变成 0.0;padding v/do0无风险
BWD S2 reverse state_chunk_gated_delta_rule_bwd_dhu_kernelreverse chunk 顺序、序列边界、dht=Nonechunk_to_seq 定位 seq;在 t0 + BT >= eos 重置 scratch;dht=None 时使用全 0dht 时 final-state 梯度为 0;序列间 scratch 不串联无风险
BWD S3 WY fused_chunk_kda_bwd_wy_dqkg_fused_kernelK/V tile tail、A 上三角、padding tokenK % BK == 0V % BV == 0;默认 BK=KBV=Vm_lower 清非严格下三角非严格下三角外的 dAkk entry 为 0.0;padding token 输入为 0无风险
BWD S4 intra backwardkda_intra_chunk_bwd_kernel_subchunkA 上三角、非严格下三角、aligned padding tokendAqk 用 causal mask;dAkk 用 strict-lower mask;T 已经 BT 对齐非因果 dAqk0.0;非严格下三角外 dAkk0.0无风险
BWD S5 reverse cumsum_chunk_cumsum_kernel_varlenS tail、BH tail、T 安全尾块、padding tokenS/BH/T 都先 jnp.pad(..., 0);输出裁回 [:BH, :T, :S];最后再 _unalign_outputS/BH/T 额外 lane 为 0;padding token 输出最终被丢弃无风险
gate backwardkda_gate_bwd()padding g_org 和可能存在的 padding dygg_org padding 已经是 -1e4dg/dbias/dA 按激活函数导数计算padding gate 的 softplus/sigmoid 导数约为 0;padding dg 最终被 _unalign_output 丢弃无风险

当前 Pallas varlen 扫描结论:没有发现类似 FLA block-pointer boundary_check 默认 undefined 的数值风险。需要维持正确性的关键条件是:所有 varlen 调用必须继续走 _align_seqs() / aligned cu_seqlensuse_gate_in_kernel=True 时必须保留 padding g=-1e4 的修复;新增 Pallas varlen kernel 如果使用 disable_bounds_checks=True,必须先在 wrapper 层把 T/K/V 和额外安全 chunk 明确补成确定值。


三、两种策略对比

维度FLA (Triton, GPU)Pallas (TPU)
Padding 策略无物理 padding_align_seqs 物理填充到 BT 倍数
边界处理多重条件检查 + 循环边界限制;boundary_check load 默认 undefinedL 矩阵数学解耦 + 门控中和
索引机制prepare_chunk_indices → 每个 kernel 内动态解码_align_seqs → 统一的 transpose+reshape
Chunk 结构最后一个 chunk 只有部分 token所有 chunk 都是 BT 个 token
计算量最后一个 chunk 做更少工作所有 chunk 做相同工作(padding 部分为无效计算)
内存T 不变T 增大(对齐后)
正确性保证T 维 varlen tail 由边界逻辑隔离;K/V/D 维归约 tail 需要对齐或显式补零k=0 → L=0 → 解耦
Kernel 复杂度每个 kernel 都需要 IS_VARLEN 路径默认 varlen wrapper 先对齐;kernel 主要处理完整 BT chunk
硬件适配GPU (Triton)TPU (K/V 128 对齐, PrefetchScalarGridSpec)

四、结论

Pallas 实现:padding 不影响真实 token 的计算结果。

原因可归结为一条数学不变量:

k=0 → L 矩阵的块对角结构 → 真实 token 与 padding token 在求解中完全解耦

具体来说:

  1. k=0 且 beta=0_align_seqs 的零填充使 L 矩阵中与 padding 相关的行/列全为零
  2. L_real 独立:真实 token 的求解方程 (I + L_real) @ X_real = RHS_real 与 padding 完全无关
  3. g=-1e4 中和:门控修复确保 g_act ≈ 0,padding 不贡献衰减
  4. 状态透明:padding 处 kg=0, v_new=0,状态传播不受影响
  5. 输出正确裁剪_unalign_output 丢弃 padding 输出

FLA 实现:T 维 varlen tail 通常不影响真实 token,但 block-pointer load 不是自动补零。

FLA 的工程策略是“不构造物理 padding token”,并在每个 Triton kernel 中解码当前序列的 bos/eos/T。这能避免 Pallas 那种对齐后 padding token 参与完整 chunk 计算的问题。

但从 tl.load 语义看,FLA 还隐含一个额外条件:所有参与 tl.sum / tl.dot 的 K/V/D 尾部 lane 必须是有效值、被 mask 成 0,或通过 padding_option="zero" 明确补零。否则 boundary_check 返回的 undefined lane 可能污染真实结果。尤其是 use_qk_l2norm_in_kernel=True 时的 l2norm_fwd_kernel / l2norm_bwd_kernel 小 D 路径,非 2 的幂 head dim 会把 undefined feature tail 纳入 norm reduction。

两者的差异在于工程策略,而非数学正确性。

  • FLA 更"干净":不引入无用数据,每个 chunk 只做必要计算,但每个 kernel 都需要 IS_VARLEN 分支,并且 block-pointer 越界 load 需要单独审计
  • Pallas 更"规整":通过前置对齐统一 chunk 大小,适配 TPU 的硬件对齐要求(128 对齐、PrefetchScalarGridSpec),代价是部分无效计算和 T 维度膨胀

五、实证验证

5.1 Pallas Varlen E2E 测试

bash
PALLAS_INTERPRET=1 JAX_PLATFORMS=cpu uv run pytest tests/ops/kda/test_varlen_e2e.py -v

结果:44 passed 全部通过。

测试覆盖场景:

测试函数用例数验证内容
test_chunk_kda_fwd_varlen19非 BT 对齐 forward(多种 seq_lens、dtype、head_dim 组合)
test_chunk_kda_varlen_bwd19非 BT 对齐 backward
test_chunk_kda_varlen_bwd_no_oob_tail_chunks3T_pad >> sum(seq_lens) 时 tail chunks 不越界
test_chunk_kda_varlen_fwd_padding_tail_not_duplicate1_unalign_output 使用 safe_default 而非位置 0
test_chunk_kda_varlen_empty_seq_final_state2N_pad > 真实序列数时的 empty seq 处理

覆盖的组合:

  • [30, 50][45, 80, 20][100] 等非 BT=64 对齐序列
  • T_pad > sum(seq_lens) 真实 padding 场景
  • N_pad=8 仅 2 个真实序列(尾随空序列)
  • initial_state=None 和非 None
  • bf16 精度

5.2 Segment E2E 测试

bash
PALLAS_INTERPRET=1 JAX_PLATFORMS=cpu uv run pytest tests/ops/kda/test_segment_e2e.py -v

5.3 FLA 库测试

FLA 库不引入物理 padding 数据,每个 chunk 仅包含真实 token。主 KDA varlen 路径的 T 维 tail 由边界逻辑隔离;但还应补充以下回归用例来覆盖 boundary_check 的 undefined lane 风险:

场景预期
use_qk_l2norm_in_kernel=TrueK 非 2 的幂(如 96/192)l2norm_fwd/bwd 应与 PyTorch reference 一致;若不一致,需给 block-pointer load 补 padding_option="zero"
KDA forward 中 KBK 对齐Aqk/Akk/w/kg/qg/o 不应受 K tail undefined lane 影响
KDA backward 中 VBV 对齐dA/dv/dq/dk/dg/db 不应受 V tail undefined lane 影响
reverse cumsum 最后一个 varlen chunk 不满 BTtail lane 不应进入真实 token 的 reverse prefix 结果