Varlen Chunk-KDA: Padding 对结果的影响分析
背景
在 varlen(变长序列)模式下,每个序列的长度 L 可能不是 chunk_size(记为 BT,通常 64)的整数倍。例如 L=100, BT=64 时,需要 2 个 chunk:第一个满 64 tokens,第二个只有 36 个真实 token。
FLA 库(Triton 实现,v0.4.2)和本地 Pallas 实现(tops/ops/kda/)采用不同的策略处理这种情况。本分析验证两种策略下,padding 是否会影响真实 token 的计算结果。
一、FLA 库(Triton 实现)
安装路径:.venv/lib/python3.12/site-packages/fla/ops/kda/
核心思路:不对输入张量做任何物理填充,每个 chunk 内只处理真实存在的 token。
1.0 公共 API
入口函数:chunk_kda() @ fla/ops/kda/chunk.py:144
def chunk_kda(
q, k, v, g, beta, # [B, T, H, D]
scale=None,
initial_state=None, # [N, H, K, V] 用于 varlen
output_final_state=False,
use_qk_l2norm_in_kernel=False,
use_gate_in_kernel=False,
cu_seqlens=None, # [N+1], varlen 模式入口
cu_seqlens_cpu=None,
safe_gate=False,
lower_bound=None,
disable_recompute=False,
return_intermediate_states=False,
cp_context=None,
transpose_state_layout=False,
**kwargs,
):内部使用 torch.autograd.Function:
ChunkKDAFunction.forward()@chunk.py:18→ 调用chunk_kda_fwd()ChunkKDAFunction.backward()@chunk.py:101→ 调用chunk_kda_bwd()
Varlen 约定(chunk.py:285-295):当 cu_seqlens is not None 时,要求 B=1、T = sum(seq_lens)(所有序列平坦化拼接),initial_state.shape[0] == len(cu_seqlens)-1。
1.1 Forward 调用链
chunk_kda() chunk.py:144
└── ChunkKDAFunction.forward() chunk.py:18
├── (可选) l2norm_fwd(q) / l2norm_fwd(k) modules/l2norm.py
├── prepare_chunk_indices(cu_seqlens, BT) [纯 PyTorch] utils/index.py:112
└── chunk_kda_fwd() chunk_fwd.py:16
│
├── [Phase A] 门控预处理
│ ├── use_gate_in_kernel=True:
│ │ kda_gate_chunk_cumsum() gate.py:415
│ │ └── Triton: kda_gate_chunk_cumsum_vector_kernel gate.py:362
│ └── use_gate_in_kernel=False:
│ chunk_local_cumsum() utils/cumsum.py:429
│ └── Triton: chunk_local_cumsum_vector_kernel cumsum.py:86
│
├── [Phase B] Intra-Chunk
│ chunk_kda_fwd_intra() chunk_intra.py:951
│ ├── safe_gate=True:
│ │ └── Triton: chunk_kda_fwd_kernel_intra_sub_chunk chunk_intra.py:817
│ ├── safe_gate=False:
│ │ └── Triton: chunk_kda_fwd_kernel_intra_token_parallel chunk_intra_token_parallel.py:25
│ ├── Triton: chunk_kda_fwd_kernel_inter_solve_fused chunk_intra.py:37
│ └── recompute_w_u_fwd() wy_fast.py:210
│ └── Triton: recompute_w_u_fwd_kda_kernel wy_fast.py:32
│ 返回: w, u, qg, kg, Aqk, Akk
│
├── [Phase C] 跨 Chunk 状态传播
│ chunk_gated_delta_rule_fwd_h() common/chunk_delta_h.py:655
│ └── Triton: chunk_gated_delta_rule_fwd_kernel_h_blockdim64 chunk_delta_h.py:35
│ 返回: h, v_new, final_state
│
└── [Phase D] 输出计算
chunk_gla_fwd_o_gk() gla/chunk.py:863
└── Triton: chunk_gla_fwd_kernel_o gla/chunk.py:308
返回: o1.2 Backward 调用链
chunk_kda_bwd() chunk_bwd.py:410
│
├── [S0] 重计算前向中间结果
│ ├── kda_gate_chunk_cumsum() (如需要) gate.py:415
│ ├── recompute_w_u_fwd() wy_fast.py:210
│ └── chunk_gated_delta_rule_fwd_h() common/chunk_delta_h.py:655
│
├── [S1] dAqk + dv
│ chunk_kda_bwd_dAv() chunk_bwd.py:294
│ └── Triton: chunk_kda_bwd_kernel_dAv chunk_bwd.py:42
│
├── [S2] 反向状态传播
│ chunk_gated_delta_rule_bwd_dhu() common/chunk_delta_h.py:715
│ └── Triton: chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 chunk_delta_h.py:345
│
├── [S3] WY 融合反向
│ chunk_kda_bwd_wy_dqkg_fused() chunk_bwd.py:345
│ └── Triton: chunk_kda_bwd_kernel_wy_dqkg_fused chunk_bwd.py:125
│
├── [S4] Intra-chunk 反向
│ chunk_kda_bwd_intra() chunk_intra.py:1030
│ └── Triton: chunk_kda_bwd_kernel_intra chunk_intra.py:367
│
├── [S5] 反向累积和 + 门控反向
│ ├── chunk_local_cumsum(reverse=True) utils/cumsum.py:429
│ └── kda_gate_bwd() (如需要) gate.py:240
│
返回: dq, dk, dv, db, dg, dh0, dA, dbias1.3 Varlen 索引映射
prepare_chunk_indices() @ fla/ops/utils/index.py:112:
@tensor_cache
def prepare_chunk_indices(cu_seqlens, chunk_size):
lens = torch.diff(cu_seqlens) # 各序列长度, 如 [100, 50]
n_chunks = triton.cdiv(lens, chunk_size) # 如 [2, 1]
indices = torch.cat([torch.arange(n) for n in n_chunks])
# indices = [0, 1, 0] (seq0: chunk0, chunk1; seq1: chunk0)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1)
# → [[0, 0], [0, 1], [1, 0]] = (seq_id, local_chunk_id)prepare_chunk_offsets() @ fla/ops/utils/index.py:125 计算跨序列的累积 chunk 数,用于索引 [NT_total, H, K, V] 张量。
1.4 Kernel 内的边界检查
每个 Triton kernel 通过 IS_VARLEN 启发式开关(基于 cu_seqlens is not None)获取当前序列边界:
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})Kernel 内部:
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32)
T = eos - bos # 动态覆盖编译时常量 T
else:
bos, eos = i_b * T, i_b * T + T所有边界检查汇总:
| 检查方式 | 位置 | 效果 |
|---|---|---|
i_t * BT >= T → return | inter_solve_fused 开头 | 整个 chunk 超出序列则跳过 |
i_ti >= T → return | intra_sub_chunk 开头 | 子块起始越界则跳过 |
if i_tc1 < T: / i_tc2 < T: / i_tc3 < T: | inter_solve_fused K 循环 | 条件加载越界子块 |
m_tc1 = (i_tc1 + o_i) < T | inter_solve_fused | 子块内逐元素有效性掩码 |
m_c = o_c < T | intra_sub_chunk | 子块内逐元素有效性掩码 |
for i in range(2, min(BC, T - i_tc0)) | 前向替换循环 | 按有效行数截断 |
min(BC//2, T - i_ti - 1) | g_ref 加载 | 门控参考点不越界 |
last_idx = min((i_t+1)*BT, T) - 1 | 状态传播 | chunk 最后有效 token |
boundary_check=(0, 1) | block pointer load/store | 阻止真实越界访存;但 load 未指定 padding_option 时,越界 lane 的返回值是 undefined,不是 0 |
1.4.1 Triton tl.load 越界语义
需要区分两类写法:
# 标量/指针 + mask: 越界 lane 明确填 other
tl.load(ptr + offsets, mask=mask, other=0.0)
# block pointer + boundary_check: 越界 lane 不访存,但默认返回 undefined
tl.load(block_ptr, boundary_check=(0, 1))
# block pointer 要明确补零时,需要显式写 padding_option
tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")因此,boundary_check 只能保证不发生非法内存访问;它本身不保证越界值是 0。是否影响真实结果,取决于这些 undefined lane 是否会进入真实 token / 真实 feature 的 reduction、dot 或递推。
1.4.2 各 API / kernel 的越界 load 影响
| API / 阶段 | 代表 kernel | 越界来源 | 当前处理 | mask 后的值 | 对结果的影响 |
|---|---|---|---|---|---|
l2norm_fwd() 大 D 路径 | l2norm_fwd_kernel1 | BD > D 的 feature tail | mask=cols < D, other=0.0 | 无效 feature lane 读成 0.0;store mask 不写回 | 无风险 |
l2norm_bwd() 大 D 路径 | l2norm_bwd_kernel1 | BD > D 的 feature tail | mask=cols < D, other=0.0 | y/dy 无效 feature lane 读成 0.0;store mask 不写回 | 无风险 |
l2norm_fwd() 小 D 路径 | l2norm_fwd_kernel | BD > D 或最后一个 BT block 超过 T | boundary_check,未指定 padding_option | load 越界 lane 是 undefined;store 越界 lane 不写回 | 有风险:影响 fwd 的 rstd / y |
l2norm_bwd() 小 D 路径 | l2norm_bwd_kernel | BD > D 或最后一个 BT block 超过 T | boundary_check,未指定 padding_option | load 越界 lane 是 undefined;store 越界 lane 不写回 | 有风险:影响 bwd 的 dx |
kda_gate_chunk_cumsum() | kda_gate_chunk_cumsum_vector_kernel | T tail、S/K tail | boundary_check;cumsum 沿 T 维;store 带 boundary_check | load 越界 lane 是 undefined;没有显式改值;store 越界 lane 不写回 | 无风险,仅限 forward cumsum |
chunk_local_cumsum() | chunk_local_cumsum_vector_kernel | T tail、S/K tail | 与 gate cumsum 类似 | load 越界 lane 是 undefined;若使用 mask/other 路径则无效 lane 为 0.0;store 越界 lane 不写回 | 无风险,仅限 forward cumsum |
chunk_kda_fwd_intra() token-parallel 路径 | chunk_kda_fwd_kernel_intra_token_parallel | T tail、K tail | i_t >= T → return;循环 j < min(T, ...);K tail 用 m_k 在部分表达式中置 0 | 越界 token 直接 return 或不进入循环;tl.where(m_k, ..., 0.0) 把无效 K lane 写成 0.0 | 有风险:K tail 未完全清零时影响 fwd 的 intra A / o |
chunk_kda_fwd_intra() inter solve 路径 | chunk_kda_fwd_kernel_inter_solve_fused | 子块 T tail、K tail、Akk/Aqk tail | 子块起点 if i_tc* < T;行 mask m_tc*;store 带 boundary_check | 不存在的子块跳过;tl.where(m_tc*[:, None], ..., 0) 把无效 token 行的 gate factor 写成 0;tl.load(..., mask=m_k, other=0) 把 K tail 的 reference gate 写成 0;store 越界 lane 不写回 | 有风险:K tail 未补零时影响 fwd 的 Aqk / Akk |
recompute_w_u_fwd() | recompute_w_u_fwd_kda_kernel | T tail、K tail、V tail | T tail 通过 last_idx、m_t、store boundary 隔离;K/V tail 多为 block-pointer load | tl.where(m_t[:, None], ..., 0) 把无效 token 行的 kg 写成 0;tl.load(..., mask=m_k, other=0.) 把 K tail 的 last gate 写成 0.0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回 | 有风险:K tail 未补零时影响 fwd 的 w / kg / qg |
chunk_gated_delta_rule_fwd_h() | chunk_gated_delta_rule_fwd_kernel_h_blockdim64 | 最后 chunk T tail、K/V block tail | last_idx 取最后真实 token;T tail 在 gate 衰减中用 m_t 置 0;store boundary | tl.where(m_t, exp(...), 0) 把无效 token 行的衰减乘子写成 0,从而无效行 b_v 变为 0;tl.load(..., mask=o_k < K, other=0.) 把 K tail 的 last gate 写成 0.0;store 越界 lane 不写回 | 有风险:K tail 未补零时影响 fwd 的 state h |
chunk_gla_fwd_o_gk() 输出 | chunk_gla_fwd_kernel_o | T tail、K tail、V tail、A tail | store boundary;A 只用因果下三角 m_s | tl.where(m_s, b_A, 0.) 把非因果的 A 上三角写成 0.0;block-pointer tail 是 undefined;store 越界 lane 不写回 | 有风险:K tail 未补零时影响 fwd 的 o |
chunk_kda_bwd_dAv() | chunk_kda_bwd_kernel_dAv | T tail、V tail、A tail | m_t/m_A 清掉无效 A;store boundary | tl.where(m_A, b_A, 0) 把无效 A entry 写成 0;tl.where(o_t[:, None] >= o_t, ..., 0.) 把非因果 dA entry 写成 0.0;store 越界 lane 不写回 | 有风险:V tail 未补零时影响 bwd 的 dA |
chunk_kda_bwd_wy_dqkg_fused() | chunk_kda_bwd_kernel_wy_dqkg_fused | T tail、K tail、V tail | m_t/m_last、last_idx 和 store boundary;K/V tail多为 block-pointer load | m_t 用于识别有效 token;m_last 只保留最后真实 token;tl.load(..., mask=m_k, other=0) 把 K tail 的 reference gate 写成 0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回 | 有风险:影响 bwd 的 dq / dk / dv / dg / db / dA |
prepare_wy_repr_bwd() | prepare_wy_repr_bwd_kda_kernel | T tail、K tail、V tail | store boundary;部分标量 tail 用 mask | tl.load(..., mask=m_k, other=0.) 的 K tail 标量读成 0.0;其他 block-pointer tail 是 undefined;store 越界 lane 不写回 | 有风险:影响 bwd 的中间梯度 dw / du |
chunk_kda_bwd_intra() | chunk_kda_bwd_kernel_intra | T tail、K tail、A tail | 有子块/循环边界和部分 mask=..., other=0 | mask=..., other=0 的无效 lane 读成 0;tl.where(..., 0.) 的无效项写成 0.0;store 越界 lane 不写回 | 有风险:K tail 未补零时影响 bwd 的 intra dA / dq / dk |
chunk_local_cumsum(reverse=True) | chunk_local_cumsum_vector_kernel | T tail、S/K tail | reverse cumsum;store boundary | 若只有 boundary_check,load tail 是 undefined;若使用 mask/other,无效 lane 是 0.0;store 越界 lane 不写回 | 有风险:影响 reverse cumsum 的真实 token 输出 |
kda_gate_bwd() | kda_gate_bwd_kernel | T tail、K tail | block-pointer boundary_check,部分 bias load boundary | block-pointer tail 是 undefined;store 越界 lane 不写回;dA 没有按 T/D tail mask 后再归约 | 有风险:影响 bwd 的 dg / db |
常见 mask 改写规则:
| 写法 | 无效 lane 的值 |
|---|---|
tl.load(ptr, mask=m, other=0) / other=0.0 | 读成 0 / 0.0 |
tl.where(m, x, 0) / tl.where(m, x, 0.0) | 写成 0 / 0.0 |
tl.where(m_t, exp(...), 0) 作为乘子 | 无效 token 的乘子为 0,对应贡献被清零 |
因果 mask,如 tl.where(m_s, A, 0.) 或 tl.where(m_A, A, 0) | 非因果或无效矩阵 entry 变成 0 / 0.0 |
tl.store(..., mask=m) 或 boundary_check store | 无效 lane 不写回,不是写成 0 |
if ...: return / if i_tc < T: / range(..., min(...)) | 对应 chunk、子块或循环迭代不执行,没有产生新值 |
关键修正:FLA 的 varlen 主路径确实没有物理 padding token,但这不等价于所有 block-pointer 越界 load 都自动补零。T 维 tail 多数被 return、min(T, ...)、m_t/m_tc、三角 mask 和 store boundary 隔离,所以不会回写真实 token;但 K/V/D 这类被 tl.dot 或 tl.sum 归约的尾部 lane,如果 block size 大于真实维度且没有显式 other=0 / padding_option="zero" / tl.where(mask, x, 0),就可能污染真实结果。
对当前 KDA 调用链,最明确的问题是可选的 use_qk_l2norm_in_kernel=True:当 head dim K 走 l2norm_fwd_kernel / l2norm_bwd_kernel 且 K 不是 BD 的整数覆盖边界(例如非 2 的幂)时,undefined feature tail 会直接进入 norm 的 reduction。修复方式是给 block-pointer load 加 padding_option="zero",或改回显式 mask/other=0.0。
1.5 具体示例
以 seq_lens = [100, 50], BT = 64 为例:
# 变量初始化
cu_seqlens = torch.tensor([0, 100, 150]) # 2 个序列, T=150
# prepare_chunk_indices 输出:
# chunk_indices = [[0, 0], ← seq0 的 chunk0 (64 tokens, 全真实)
# [0, 1], ← seq0 的 chunk1 (36 tokens, 无填充)
# [1, 0]] ← seq1 的 chunk0 (50 tokens, 无填充)
# NT = 3 (总 chunk 数), 而如果定长则 NT = ceil(150/64) = 3
# 但 varlen 下每个 chunk 的 T 值不同:
# seq0-chunk0: T=100, bos=0, eos=100
# seq0-chunk1: T=100, bos=0, eos=100, 仅处理 positions 64-99
# seq1-chunk0: T=50, bos=100, eos=150, 仅处理 positions 100-149Kernel 内部处理(以 chunk_kda_fwd_kernel_inter_solve_fused 为例 @ chunk_intra.py:37):
# chunk0 of seq0 (i_t = 0): i_t * BT = 0, T = 100, 0 < 100 ✓
# 处理 4 个子块 (BC=16): tc0=0, tc1=16, tc2=32, tc3=48 — 全部有效
# chunk1 of seq0 (i_t = 1): i_t * BT = 64, T = 100, 64 < 100 ✓
# 子块 tc0=64 (valid), tc1=80 (valid), tc2=96 (valid 但最后4个token越界)
# tc3=112 ≥ 100 → 跳过整个 tc3
# 在 tc2 中: m_tc2 = (96+o_i) < 100 → 只有 o_i ∈ {0,1,2,3} 有效
# 前向替换: for i in range(2, min(16, 100-96)) = range(2, 4) → 仅处理行 2,3
# chunk0 of seq1 (grid_idx = 2): chunk_indices[2] = [1, 0]
# 解码后 i_n=1, i_t=0,因此使用 seq 内 local chunk 坐标
# i_t * BT = 0, T = 50, bos=100, eos=150
# 子块 tc0=0, tc1=16, tc2=32, tc3=48 — 全部 < 50,全部会进入条件分支
# 在 tc3 中: m_tc3 = (48+o_i) < 50 → 只有 o_i ∈ {0,1} 有效
# 前向替换: for i in range(2, min(16, 50-48)) = range(2, 2) → 空循环1.6 结论
FLA 库完全不引入物理 padding 数据。最后一个 chunk 只包含真实 token。 Varlen 的 T 维尾部通常通过 kernel 内的 chunk 边界、行 mask、三角 mask 和 store boundary 隔离,因此不会写回真实 token。
但要注意:tl.load(block_ptr, boundary_check=...) 默认不是补零。只要越界 lane 进入 tl.sum / tl.dot 这类跨 feature 或跨 value 维度的归约,就可能影响真实结果。当前文档中的 FLA 正确性结论应理解为:T 维 varlen tail 在主 KDA 路径中被隔离;K/V/D 维 tail 还需要维度与 block size 对齐,或显式补零/掩码。
二、Pallas 实现(TPU)
实现路径:tops/ops/kda/、tops/ops/common/
核心思路:通过 _align_seqs 将每个序列填充到 BT 整数倍,利用 padding token 的零值特性使其在数学上透明(L 矩阵块对角解耦)。
2.0 公共 API
入口函数:chunk_kda() @ tops/ops/kda/chunk.py:49
@jax.custom_vjp
def chunk_kda(
q, k, v, g, beta, # [B, T, H, D]
A_log=None, dt_bias=None,
scale=None,
initial_state=None, # [N, H, K, V] 用于 varlen
output_final_state=False,
use_qk_l2norm_in_kernel=False,
use_gate_in_kernel=False,
segment_ids=None, # [B, T] 或 [1, T], varlen 模式入口
safe_gate=False,
lower_bound=None,
disable_recompute=False,
cp_context=None,
transpose_state_layout=False,
chunk_size=64,
**kwargs,
):自定义 VJP 绑定 @ chunk.py:428:
chunk_kda.defvjp(_chunk_kda_fwd_custom, _chunk_kda_bwd_custom)Varlen 逻辑 chunk.py:88:当 segment_ids is not None 时,调用 segment_ids_to_seqlens() 提取 cu_seqlens;也可以通过 kwargs["cu_seqlens"] 直接传入。下面 Pallas kernel 分析默认讨论 varlen 路径,即 cu_seqlens is not None,非 varlen 只作为共享 wrapper 的背景,不单独展开。
2.1 Forward 调用链(默认 varlen)
chunk_kda() chunk.py:49
└── custom_vjp:
_chunk_kda_fwd_custom() chunk.py:129
│
│ [Varlen 预处理]
│ ├── segment_ids → cu_seqlens (如需要) utils.py:193
│ ├── L2 norm on q/k (如需要)
│ ├── _align_seqs() → 序列对齐到 BT 倍数 chunk_fwd.py:566
│ ├── gate padding 修复: g[padding] = -1e4 chunk.py:176-185
│ └── compute_padded_cu_seqlens() tops/utils.py:130
│
└── chunk_kda_fwd(_skip_align=True, cu_seqlens=aligned_cu)
│ chunk_fwd.py:623
│
├── [Stage 1+2] kda_fwd_intra_fused() chunk_intra_fwd_fused.py:586
│ │
│ │ bf16 varlen 路径:
│ └── kda_fwd_intra_fused_varlen() chunk_intra_fwd_fused.py:438
│ └── Pallas #1: _fused_gate_intra_kernel chunk_intra_fwd_fused.py:37
│ 功能: 门控激活 + cumsum + BC=16 Aqk/L + Neumann 求逆
│ Grid: (1, H, NC)
│ BlockSpec: 每块 [1, 1, BT, D]
│ 返回: w, u, qg, kg, Aqk, Akk, g_cumsum
│
│ fp32 varlen fallback:
│ └── kda_gate_chunk_cumsum() → kda_fwd_intra()
│ 分别做 S1 和 S2 (不做 BC=16 融合)
│
└── [Stage 3+4, varlen]
chunk_kda_fwd_h_o_varlen() chunk_fwd.py:359
└── Pallas #2: _chunk_kda_fwd_h_o_varlen_kernel chunk_fwd.py:247
功能: 状态传播 + 输出计算 (融合)
Grid: (H, NT), PrefetchScalarGridSpec(num_scalar_prefetch=2)
Scratch: VMEM (K_PADSIZE, V_ALIGNED)
BlockSpec: 每块 [1, 1, BT, D]
返回: o [1, T, H, V], final_state [N, H, K, V]
[Varlen 后处理]
└── _unalign_output() → 反向 gather 回原始 T chunk_fwd.py:6012.2 Backward 调用链(默认 varlen)
_chunk_kda_bwd_custom() chunk.py:268
│
│ [Varlen 预处理]
│ ├── _align_seqs(do) → 对齐梯度
│ ├── 转置所有张量 [B,T,H,X] → [H,B,T,X]
│ └── 计算 aligned cu_seqlens / chunk_indices
│
└── chunk_kda_bwd(cu_seqlens=aligned_cu, chunk_indices=...) chunk_bwd.py:815
│
├── [S0] 重计算前向中间结果
│ ├── kda_gate_chunk_cumsum() (如 use_gate_in_kernel)
│ │ └── chunk_local_cumsum_vector() cumsum.py:492
│ │ └── Pallas #8: _chunk_cumsum_kernel_varlen cumsum.py:314
│ ├── recompute_w_u_fwd() (CPU reference)
│ └── chunk_gated_delta_rule_fwd_h() chunk_delta_h.py:844
│ └── Pallas #3: _chunk_gated_delta_rule_fwd_varlen_kernel chunk_delta_h.py:188
│ Grid: (H, NT), PrefetchScalarGridSpec(2)
│ Scratch: VMEM (K_PADSIZE, V_ALIGNED)
│
├── [S1] dAqk + dv
│ chunk_kda_bwd_dAv_kernel() chunk_bwd.py:470
│ └── Pallas #4: _chunk_kda_bwd_dAv_kernel chunk_bwd.py:401
│ Grid: (BH * NT,), 1D 扁平化
│ 计算: dA = do @ v^T * scale, dv = A_masked^T @ do
│
├── [S2] 反向状态传播
│ chunk_gated_delta_rule_bwd_dhu_kernel() chunk_bwd.py:657
│ └── Pallas #5: _chunk_gated_delta_rule_bwd_dhu_kernel chunk_bwd.py:574
│ Grid: (H, NT), PrefetchScalarGridSpec(2)
│ Scratch: VMEM (K, V)
│ Varlen 动态切片: dht/dh0 按 segment 切片避免加载全 N 维度
│ 反向迭代: i_t = NT - 1 - i_c
│
├── [S3] WY 融合反向
│ chunk_kda_bwd_wy_dqkg_fused_kernel() chunk_bwd.py:220
│ └── Pallas #6: _chunk_kda_bwd_wy_dqkg_fused_kernel chunk_bwd.py:25
│ Grid: (BH * NT,), 展开 K/V 循环
│ 输入11个, 输出6个: dq, dk, dv2, dg, db, dA
│
├── [S4] Intra-chunk 反向
│ chunk_kda_bwd_intra() chunk_intra.py:588
│ └── Pallas #7: kda_intra_chunk_bwd_kernel_subchunk chunk_intra.py:306
│ Grid: (H, B, NT), BC=16 子块分解
│ 累加 inter 和 intra 贡献
│
├── [S5] 反向累积和
│ chunk_local_cumsum_vector(reverse=True) cumsum.py:492
│ └── Pallas #8: _chunk_cumsum_kernel_varlen cumsum.py:314
│ Hillis-Steele 扫描 + fori_loop over chunks
│
└── 门控反向 (如需要)
kda_gate_bwd() gate.py:95
返回: dq, dk, dv, dg, db, dh0, dA, dbias
[Varlen 后处理]
├── 转置 [H,B,T,X] → [B,T,H,X]
├── L2 norm backward (如需要)
└── _unalign_output(dq, dk, dv, dg, db)2.3 Pallas Varlen Kernel 汇总(默认路径)
| # | Kernel 函数 | 文件:行 | 调用者 | 阶段 |
|---|---|---|---|---|
| 1 | _fused_gate_intra_kernel | chunk_intra_fwd_fused.py:37 | kda_fwd_intra_fused_varlen() | FWD S1+S2 |
| 2 | _chunk_kda_fwd_h_o_varlen_kernel | chunk_fwd.py:247 | chunk_kda_fwd_h_o_varlen() | FWD S3+S4 |
| 3 | _chunk_gated_delta_rule_fwd_varlen_kernel | chunk_delta_h.py:188 | _chunk_gated_delta_rule_fwd_varlen() | BWD S0 |
| 4 | _chunk_kda_bwd_dAv_kernel | chunk_bwd.py:401 | chunk_kda_bwd_dAv_kernel() | BWD S1 |
| 5 | _chunk_gated_delta_rule_bwd_dhu_kernel | chunk_bwd.py:574 | chunk_gated_delta_rule_bwd_dhu_kernel() | BWD S2 |
| 6 | _chunk_kda_bwd_wy_dqkg_fused_kernel | chunk_bwd.py:25 | chunk_kda_bwd_wy_dqkg_fused_kernel() | BWD S3 |
| 7 | kda_intra_chunk_bwd_kernel_subchunk | chunk_intra.py:306 | _kda_intra_chunk_bwd_subchunk_pallas() | BWD S4 |
| 8 | _chunk_cumsum_kernel_varlen | cumsum.py:314 | _chunk_local_cumsum_pallas() | BWD S0/S5 |
2.4 关键工具函数
| 函数 | 文件:行 | 用途 |
|---|---|---|
_segment_ids_to_packed() | chunk_fwd.py:203 | 多 batch segment_ids → 平坦化 B=1 布局 |
_align_seqs() | chunk_fwd.py:566 | 每个序列对齐到 BT 倍数, 零填充 |
_unalign_output() | chunk_fwd.py:601 | 反向 gather, 丢弃 padding 输出 |
prepare_chunk_indices() | tops/utils.py:255 | 返回 [NT, 2] 的 (seq_id, chunk_id) |
compute_padded_cu_seqlens() | tops/utils.py:130 | 对齐后的 cu_seqlens |
segment_ids_to_seqlens() | tops/utils.py:193 | 1D segment_ids → cu_seqlens |
_build_chunk_map() | common/chunk_h.py:13 | searchsorted 建立 chunk→seq 映射 |
chunk_local_cumsum_vector() | common/cumsum.py:492 | 门控 chunk 内累积和分发器 |
2.5 _align_seqs — Padding 机制
_align_seqs() @ chunk_fwd.py:566-598:
def _align_seqs(tensors_4d, tensors_3d, cu_seqlens, align):
N = cu_seqlens.shape[0] - 1
T_old = tensors_4d[0].shape[1]
seg_lens = cu_seqlens[1:] - cu_seqlens[:-1] # 如 [100, 50]
padded_lens = ((seg_lens + align - 1) // align) * align # → [128, 64]
padded_cu = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32),
jnp.cumsum(padded_lens)]) # → [0, 128, 192]
T_new = ((T_old + N * (align - 1) + align - 1) // align) * align
# gather_idx 映射: 真实位置→原位置, padding 位置→T_old (零填充区)
gather_idx = jnp.full(T_new, T_old, dtype=jnp.int32)
gather_idx = jax.lax.fori_loop(0, N, _build_gather, gather_idx)
# 先 pad 零, 再 gather
def repack_4d(t):
return jnp.pad(t, ((0,0), (0, T_new-T_old), (0,0), (0,0)))[:, gather_idx]结果:padding 位置的 q=k=v=beta=0。
2.6 门控中和修复
chunk.py:176-185:
if use_gate_in_kernel:
orig_lens = _orig_cu_seqlens[1:] - _orig_cu_seqlens[:-1]
aligned_starts = cu_seqlens[:-1]
pos = jnp.arange(T)
in_range = (pos[None, :] >= aligned_starts[:, None]) & (
pos[None, :] < (aligned_starts + orig_lens)[:, None]
)
valid_mask = in_range.any(axis=0) # [T], True=真实 token
g = jnp.where(valid_mask[None, :, None, None], g, -1e4)_align_seqs 将 gate 填充为 0,但 softplus(0 + dt_bias) 可能非零。将 padding 位置的 g 设为 -1e4 确保 softplus(-1e4 + dt_bias) ≈ 0。
2.7 _unalign_output — 输出反对齐
_unalign_output() @ chunk_fwd.py:601-620:
def _unalign_output(o, orig_cu_seqlens, aligned_cu_seqlens, T_out):
N = orig_cu_seqlens.shape[0] - 1
safe_default = aligned_cu_seqlens[-1] # 尾部零位置,确保读取零而非位置0
gather_idx = jnp.full(T_out, safe_default, dtype=jnp.int32)
gather_idx = jax.lax.fori_loop(0, N, _build_gather, gather_idx)
return o[:, gather_idx] # [1, T_out, H, V]关键:safe_default = aligned_cu_seqlens[-1] 指向对齐后所有序列末尾的 padding 区域(保证为零),避免之前 safe_default=0 的 bug(会重复读取位置 0 的输出)。
2.8 具体示例
以 seq_lens = [100, 50], BT = 64 为例:
# === _align_seqs 之后 ===
# T_old = 150
# padded_lens = [128, 64] (BT=64 对齐)
# padded_cu[-1] = 192 (128 + 64, 有效对齐后总长度)
# T_new = 320 ((_align_seqs 内部申请的安全物理上界))
# padded_cu = [0, 128, 192]
#
# 布局:
# positions 0..99: seq0 真实 (原 0..99)
# positions 100..127: seq0 padding (零)
# positions 128..177: seq1 真实 (原 100..149)
# positions 178..191: seq1 padding (零)
# positions 192..319: 安全尾部 padding (零, 不属于任何真实 seq chunk)
#
# Stage 1+2 fused intra 的张量 chunk 结构 (NC = T_new / BT = 320/64 = 5):
# chunk 0: positions 0..63 = seq0 真实
# chunk 1: positions 64..127 = seq0 真实(36) + padding(28)
# chunk 2: positions 128..191 = seq1 真实(50) + padding(14)
# chunk 3: positions 192..255 = 安全尾部 padding (纯零)
# chunk 4: positions 256..319 = 安全尾部 padding (纯零)
#
# Stage 3+4 fused state/output 的 varlen chunk 结构 (NT = len(chunk_indices) = 3):
# chunk_indices 只覆盖 padded_cu[-1] = 192 以内的 3 个有效对齐 chunk
# Stage 1+2 产生的 chunk 3/4 是纯零安全尾部,不会被 Stage 3+4 读取这里要区分两个长度和两个 chunk 计数:T_new 是 _align_seqs 为 gather_idx 和 jnp.pad 申请的安全上界,Stage 1+2 会按 T_new / BT reshape 成完整 chunk;padded_cu[-1] 才是所有序列按 BT 对齐后的有效总长度,Stage 3+4 的 chunk_indices 只覆盖这部分有效 chunk。超过 padded_cu[-1] 的尾部位置由 jnp.pad(..., 0) 提供,是纯零安全尾部,不会被 Stage 3+4 读取,也不会被 _unalign_output() gather 回真实输出。
2.9 Padding 在求解中的数学解耦
Stage 1+2:L 矩阵块对角结构
_fused_gate_intra_kernel @ chunk_intra_fwd_fused.py:37 中的关键推导:
门控累积和:
g_f32 = -exp(A_val) * softplus(g_f32 + dt_bias) # padding: g=-1e4 → g_act≈0
g_cumsum = tril @ g_f32 # padding 行: 平坦继承最后真实 token 的 cumsumL 矩阵(BC=16 子块分解):
# L[i,j] = beta_i * (k_eg_i @ k_eng_j^T)
# k_eg = k * exp2(g_cumsum - g_ref)
# k_eng = k * exp2(g_ref - g_cumsum)对于 padding 位置:
- padding 行 i:
k_i=0且beta_i=0→ L[padding, :] = 0 - padding 列 j:
k_j=0→k_i @ k_j^T = 0→ L[:, padding] = 0
L 矩阵呈现块对角结构:
L = [L_real, 0 ] ← BT_real × BT_real 有效块
[0, 0 ] ← BT_pad × BT_pad 零块求解 (I+L)x = rhs:
- padding 行: RHS =
[v_beta=0, k_eg_beta=0, I_padding_row] (I + 0)^{-1} @ [0, 0, I_row] = [0, 0, I_row]- → u[padding]=0, w[padding]=0, A_inv[padding]=one_hot(padding)
真实行的解 (I + L_real)^{-1} @ rhs_real 与 padding 完全无关。
Stage 3+4:Padding 透明传播
_chunk_kda_fwd_h_o_varlen_kernel @ chunk_fwd.py:247 中的关键推导:
v_new 计算(Stage 3 修正值):
b_v_new = b_u - b_w @ b_h # padding: 0 - 0 @ h = 0输出计算(Stage 4):
b_qg = b_q * exp2(g - g_ref) # padding: 0 * exp2(...) = 0
b_o_inter = scale * b_qg @ (b_h * exp2(g_ref)) # = 0
b_o_intra = tril(A) @ b_v_new # = tril(A) @ 0 = 0状态更新:
b_gk_last = gk_ref[BT-1] # padding chunk: = g_cumsum[last_real]
scratch *= exp2(b_gk_last) # 按真实衰减率衰减
scratch += b_kg^T @ b_v_new # padding: 0^T @ 0 = 0, 状态不变2.10 逐 Token 追踪(chunk1: 36真实 + 28padding)
Token | 64 65 ... 99 | 100 101 ... 127
--------|--------------|-----------------
类型 | 真实 | padding
g_orig | 实值 | -1e4 (修复后)
g_act | 实值 | softplus(-1e4) ≈ 0
g_cumsum| 递增 | 平坦 = cumsum[99]
q,k,v | 实值 | 0 (align 填充)
beta | 实值 | 0 (align 填充)
--------|--------------|-----------------
L 矩阵 | L[0:36, 0:36] 实值 | L[36:64, :] = 0
| L[0:36, 36:64] = 0 | L[:, 36:64] = 0
RHS | v_beta, k_eg_beta 实值 | 0, 0
--------|--------------|-----------------
u | 实值 (来自求解) | 0
w | 实值 | 0
kg | 实值 | 0
qg | 实值 | 0
--------|--------------|-----------------
v_new | u - w@h 实值 | 0 - 0@h = 0
o | inter + intra 实值 | 0 + 0 = 0
状态 | h *= exp2(cumsum[99]) | h *= exp2(cumsum[99])
更新 | h += kg^T @ v_new 实值 | h += 0 = h2.11 当前 Pallas Varlen 实现的 padding / 越界处理影响
本表默认按 Pallas varlen 路径审计。当前 Pallas varlen 路径和 FLA 的核心差异是:Pallas 不依赖 kernel 内 boundary_check 给越界 lane 兜底,而是在进入 Pallas kernel 之前把 T/K/V 维度补到 kernel 需要的安全形状。Pallas varlen kernel 中即使设置了 disable_bounds_checks=True,读写对象也是已经 jnp.pad 或 _align_seqs 过的张量,因此不会出现 FLA tl.load(block_ptr, boundary_check=...) 那种 undefined lane 语义。
| API / 阶段 | 代表 kernel / 函数 | padding / 越界来源 | 当前处理 | mask / pad 后的值 | 对结果的影响 |
|---|---|---|---|---|---|
chunk_kda() varlen 入口 | _segment_ids_to_packed() / segment_ids_to_seqlens() | 多 batch segment_ids 被展平到 B=1 | 只做 layout pack 和 cu_seqlens 生成,不引入 padding 值 | 没有新增 padding lane | 无风险 |
| varlen 前向对齐 | _align_seqs() | 每个序列长度不是 BT 整数倍 | 先 jnp.pad(..., 0),再用 gather_idx 把每段放到 BT 对齐位置 | q/k/v/beta 的 padding token 变成 0;g 初始 padding 为 0 | 无风险,但 use_gate_in_kernel=True 还依赖下一行 gate 修复 |
| gate padding 修复 | _chunk_kda_fwd_custom() / chunk_kda_fwd() | _align_seqs() 会把 raw gate padding 成 0 | 对 padding 位置执行 jnp.where(valid_mask, g, -1e4) | padding g 变成 -1e4,softplus(-1e4 + dt_bias) ≈ 0 | 无风险 |
use_qk_l2norm_in_kernel=True | tops.cpu.ops.common.l2norm_fwd/bwd | K 非 2 的幂或非 block 对齐 | 当前 Pallas 路径在 JAX/CPU reference 上做整张量 norm,不使用 block-pointer load | 没有 undefined feature tail;padding token 的 q/k 已经是 0 | 无风险 |
| FWD S1+S2 fused intra | _fused_gate_intra_kernel / kda_fwd_intra_fused_varlen() | T tail 来自 _align_seqs();K/V 维按真实维度进入该 kernel | 要求 T % BT == 0;直接 reshape 成完整 chunk;因果上三角用零矩阵或 mask 清掉 | padding token 的 q/k/v/beta=0;padding gate 激活约为 0;非因果 Aqk/Akk entry 为 0 | 无风险 |
| FWD S3+S4 fused state/output | _chunk_kda_fwd_h_o_varlen_kernel | K tail、V tail、额外安全 T chunk、空序列 final_state | wrapper 先把 K/V 补到 128 对齐,T 额外补 BT;输出再裁回真实 K/V/T;empty seq 用 jnp.where 填初始状态或 0 | K/V/T padding 都是 0;empty final_state 变成 initial_state 或 0 | 无风险 |
| BWD 对齐 | _chunk_kda_bwd_custom() + _align_seqs(do) | do 原始长度不是 BT 对齐 | backward 开始先把 do 对齐到和 forward residual 一致的 T | padding do 变成 0 | 无风险 |
| BWD S0 重计算 | kda_gate_chunk_cumsum() / recompute_w_u_fwd() / chunk_gated_delta_rule_fwd_h() | aligned padding token、K/V 对齐 padding | 使用 forward residual 的对齐张量;h chunk 数不足时 jnp.pad(..., 0) | padding token 继续满足 q/k/v/beta/do=0;padding h chunk 为 0 | 无风险 |
| BWD S1 dAqk + dv | _chunk_kda_bwd_dAv_kernel | V tail、A 上三角、padding token | block_V 默认等于 V,若自定义必须整除;A 用 causal mask 清上三角 | 非因果 A/dA entry 变成 0.0;padding v/do 为 0 | 无风险 |
| BWD S2 reverse state | _chunk_gated_delta_rule_bwd_dhu_kernel | reverse chunk 顺序、序列边界、dht=None | chunk_to_seq 定位 seq;在 t0 + BT >= eos 重置 scratch;dht=None 时使用全 0 | 无 dht 时 final-state 梯度为 0;序列间 scratch 不串联 | 无风险 |
| BWD S3 WY fused | _chunk_kda_bwd_wy_dqkg_fused_kernel | K/V tile tail、A 上三角、padding token | K % BK == 0、V % BV == 0;默认 BK=K、BV=V;m_lower 清非严格下三角 | 非严格下三角外的 dAkk entry 为 0.0;padding token 输入为 0 | 无风险 |
| BWD S4 intra backward | kda_intra_chunk_bwd_kernel_subchunk | A 上三角、非严格下三角、aligned padding token | dAqk 用 causal mask;dAkk 用 strict-lower mask;T 已经 BT 对齐 | 非因果 dAqk 为 0.0;非严格下三角外 dAkk 为 0.0 | 无风险 |
| BWD S5 reverse cumsum | _chunk_cumsum_kernel_varlen | S tail、BH tail、T 安全尾块、padding token | S/BH/T 都先 jnp.pad(..., 0);输出裁回 [:BH, :T, :S];最后再 _unalign_output | S/BH/T 额外 lane 为 0;padding token 输出最终被丢弃 | 无风险 |
| gate backward | kda_gate_bwd() | padding g_org 和可能存在的 padding dyg | g_org padding 已经是 -1e4;dg/dbias/dA 按激活函数导数计算 | padding gate 的 softplus/sigmoid 导数约为 0;padding dg 最终被 _unalign_output 丢弃 | 无风险 |
当前 Pallas varlen 扫描结论:没有发现类似 FLA block-pointer boundary_check 默认 undefined 的数值风险。需要维持正确性的关键条件是:所有 varlen 调用必须继续走 _align_seqs() / aligned cu_seqlens;use_gate_in_kernel=True 时必须保留 padding g=-1e4 的修复;新增 Pallas varlen kernel 如果使用 disable_bounds_checks=True,必须先在 wrapper 层把 T/K/V 和额外安全 chunk 明确补成确定值。
三、两种策略对比
| 维度 | FLA (Triton, GPU) | Pallas (TPU) |
|---|---|---|
| Padding 策略 | 无物理 padding | _align_seqs 物理填充到 BT 倍数 |
| 边界处理 | 多重条件检查 + 循环边界限制;boundary_check load 默认 undefined | L 矩阵数学解耦 + 门控中和 |
| 索引机制 | prepare_chunk_indices → 每个 kernel 内动态解码 | _align_seqs → 统一的 transpose+reshape |
| Chunk 结构 | 最后一个 chunk 只有部分 token | 所有 chunk 都是 BT 个 token |
| 计算量 | 最后一个 chunk 做更少工作 | 所有 chunk 做相同工作(padding 部分为无效计算) |
| 内存 | T 不变 | T 增大(对齐后) |
| 正确性保证 | T 维 varlen tail 由边界逻辑隔离;K/V/D 维归约 tail 需要对齐或显式补零 | k=0 → L=0 → 解耦 |
| Kernel 复杂度 | 每个 kernel 都需要 IS_VARLEN 路径 | 默认 varlen wrapper 先对齐;kernel 主要处理完整 BT chunk |
| 硬件适配 | GPU (Triton) | TPU (K/V 128 对齐, PrefetchScalarGridSpec) |
四、结论
Pallas 实现:padding 不影响真实 token 的计算结果。
原因可归结为一条数学不变量:
k=0→ L 矩阵的块对角结构 → 真实 token 与 padding token 在求解中完全解耦
具体来说:
- k=0 且 beta=0:
_align_seqs的零填充使 L 矩阵中与 padding 相关的行/列全为零 - L_real 独立:真实 token 的求解方程
(I + L_real) @ X_real = RHS_real与 padding 完全无关 - g=-1e4 中和:门控修复确保
g_act ≈ 0,padding 不贡献衰减 - 状态透明:padding 处 kg=0, v_new=0,状态传播不受影响
- 输出正确裁剪:
_unalign_output丢弃 padding 输出
FLA 实现:T 维 varlen tail 通常不影响真实 token,但 block-pointer load 不是自动补零。
FLA 的工程策略是“不构造物理 padding token”,并在每个 Triton kernel 中解码当前序列的 bos/eos/T。这能避免 Pallas 那种对齐后 padding token 参与完整 chunk 计算的问题。
但从 tl.load 语义看,FLA 还隐含一个额外条件:所有参与 tl.sum / tl.dot 的 K/V/D 尾部 lane 必须是有效值、被 mask 成 0,或通过 padding_option="zero" 明确补零。否则 boundary_check 返回的 undefined lane 可能污染真实结果。尤其是 use_qk_l2norm_in_kernel=True 时的 l2norm_fwd_kernel / l2norm_bwd_kernel 小 D 路径,非 2 的幂 head dim 会把 undefined feature tail 纳入 norm reduction。
两者的差异在于工程策略,而非数学正确性。
- FLA 更"干净":不引入无用数据,每个 chunk 只做必要计算,但每个 kernel 都需要
IS_VARLEN分支,并且 block-pointer 越界 load 需要单独审计 - Pallas 更"规整":通过前置对齐统一 chunk 大小,适配 TPU 的硬件对齐要求(128 对齐、PrefetchScalarGridSpec),代价是部分无效计算和 T 维度膨胀
五、实证验证
5.1 Pallas Varlen E2E 测试
PALLAS_INTERPRET=1 JAX_PLATFORMS=cpu uv run pytest tests/ops/kda/test_varlen_e2e.py -v结果:44 passed 全部通过。
测试覆盖场景:
| 测试函数 | 用例数 | 验证内容 |
|---|---|---|
test_chunk_kda_fwd_varlen | 19 | 非 BT 对齐 forward(多种 seq_lens、dtype、head_dim 组合) |
test_chunk_kda_varlen_bwd | 19 | 非 BT 对齐 backward |
test_chunk_kda_varlen_bwd_no_oob_tail_chunks | 3 | T_pad >> sum(seq_lens) 时 tail chunks 不越界 |
test_chunk_kda_varlen_fwd_padding_tail_not_duplicate | 1 | _unalign_output 使用 safe_default 而非位置 0 |
test_chunk_kda_varlen_empty_seq_final_state | 2 | N_pad > 真实序列数时的 empty seq 处理 |
覆盖的组合:
[30, 50]、[45, 80, 20]、[100]等非 BT=64 对齐序列T_pad > sum(seq_lens)真实 padding 场景N_pad=8仅 2 个真实序列(尾随空序列)initial_state=None和非 None- bf16 精度
5.2 Segment E2E 测试
PALLAS_INTERPRET=1 JAX_PLATFORMS=cpu uv run pytest tests/ops/kda/test_segment_e2e.py -v5.3 FLA 库测试
FLA 库不引入物理 padding 数据,每个 chunk 仅包含真实 token。主 KDA varlen 路径的 T 维 tail 由边界逻辑隔离;但还应补充以下回归用例来覆盖 boundary_check 的 undefined lane 风险:
| 场景 | 预期 |
|---|---|
use_qk_l2norm_in_kernel=True 且 K 非 2 的幂(如 96/192) | l2norm_fwd/bwd 应与 PyTorch reference 一致;若不一致,需给 block-pointer load 补 padding_option="zero" |
KDA forward 中 K 非 BK 对齐 | Aqk/Akk/w/kg/qg/o 不应受 K tail undefined lane 影响 |
KDA backward 中 V 非 BV 对齐 | dA/dv/dq/dk/dg/db 不应受 V tail undefined lane 影响 |
reverse cumsum 最后一个 varlen chunk 不满 BT | tail lane 不应进入真实 token 的 reverse prefix 结果 |