Skip to content

Kernel 文档模板

新增 kernel 时,复制本模板并填写各节内容。模板中 {placeholder} 为需要替换的占位符。


概述

  • 算法简述
  • 适用场景{在哪些模型/模块中使用,如 MaxText GLA attention、sglang-jax MoE routing}
  • 性能定位:{compute-bound / memory-bound},Arithmetic Intensity ≈

算法设计

数学公式

{核心计算的数学表达式,使用 LaTeX 语法}

O=softmax(QKTdk)V

计算流程

与标准实现的差异

实现方案

Grid/Block 划分

python
grid = ({grid_dims})
dimension_semantics = ({semantics})
block_size = {value}  # 选择原因:{解释}
Grid 维度semantics说明
batchBparallel无依赖
headHparallel无依赖
seq_chunkT/chunk_sizearbitrarychunk 间有状态依赖

内存布局与数据搬运

text
HBM: Q[B,T,H,K], K[B,T,H,K], V[B,T,H,V]

  │ DMA (BlockSpec 自动管理 / 手动)

VMEM: q_tile[chunk,K], k_tile[chunk,K], v_tile[chunk,V]

  │ MXU / Vector

VMEM: o_tile[chunk,V]

  │ DMA

HBM: O[B,T,H,V]

计算单元分配

计算单元说明
QK^T matmulMXU主要计算
softmax / expVector + EUP非线性
索引计算Scalar ALU已优化(SMEM 查找表 / unroll)

Forward / Backward

  • Forward
  • Backward:{反向实现要点,包含 custom_vjp 设计}
  • Fused 路径

关键优化点

API 接口

python
def {kernel_name}(
    {param1}: jax.Array,  # shape [{shape}], dtype {dtype}
    {param2}: jax.Array,  # shape [{shape}], dtype {dtype}
    *,
    {kwarg1}: {type} = {default},
) -> {return_type}:
    """
    {一句话说明}

    Args:
        {param1}: {说明},shape 约束:{约束}
        {param2}: {说明}

    Returns:
        {说明},shape [{shape}]
    """

使用示例

python
import jax
import jax.numpy as jnp
from tops.kernels.{name} import {kernel_name}

key = jax.random.PRNGKey(0)
{input_init_code}

output = {kernel_name}({args})

测试

项目
参考实现
rtol (float32)
atol (float32)
rtol (bfloat16)
atol (bfloat16)

覆盖的边界条件

  • [ ] 最小对齐尺寸 (128×128)
  • [ ] 非对齐维度 (127, 255, 513)
  • [ ] 极大输入(接近 VMEM 上限)
  • [ ] 特殊值 (NaN, Inf, 零矩阵)
  • [ ] 所有支持的 dtype

Benchmark

测试矩阵

ShapedtypePallas (ms)Native (ms)SpeedupTFLOPS
bf16
bf16
f32

Roofline 分析

指标
Arithmetic Intensity{AI} FLOP/Byte
理论峰值 TFLOPS
实测 TFLOPS
MFU{mfu}%
Bound 类型

已知限制

限制说明
支持的 dtype
输入形状限制{如:T 必须是 chunk_size 的倍数}
TPU 版本兼容性{v4: 支持, v5e: 支持, v6e: 验证中, v7x: 支持}