Skip to content

pallas-kernel

pallas-kernel(PyPI 包名 tops)是 primatrix 团队的 TPU Pallas kernel 集合,为 MaxText、sglang-jax 等上层项目提供高性能、可测试、可复用的 Pallas kernel 实现。

核心目标

  • 高性能:每个 kernel 经过 Roofline 分析和 profiling 验证,逼近硬件理论上限
  • 可测试:每个 kernel 提供与 NumPy/PyTorch 参考实现的数值精度对比测试
  • 可复用:统一的 API 设计,支持通过 pip install 作为依赖引入

技术栈

组件版本
Python≥ 3.12
JAX≥ 0.7.0
Pallas APIjax.experimental.pallas.tpu
构建工具uv

目标 TPU 硬件

项目主要面向以下 TPU 代次,kernel 文档中需标注各代次的兼容性:

TPU 代次说明
v4基础支持
v5e推理优化
v6e当前开发主力
v7x高端训练

文档导航

项目规范

文档说明
硬件约束与 API 限制TPU 硬件架构、Pallas API 硬性约束、内存对齐规则、性能陷阱
代码规范项目结构、命名约定、代码风格、kernel 实现规范
CI/CD 流水线持续集成阶段、TPU 集成测试、pre-commit 钩子、版本发布
开发流程分支策略、kernel 开发生命周期、PR 规范、测试要求
Benchmark 规范Roofline 分析方法、benchmark 编写规范、profiling 工具链
Kernel 文档模板新增 kernel 时复制使用的标准化文档模板

性能工程参考

以下文档基于 How to Scale Your Model 系列整理,提供 kernel 开发所需的性能分析背景知识。

文档说明
Roofline 分析深度指南算术强度公式、临界 batch size、tiling 强度、多芯片通信 Roofline
TPU 硬件规格参考跨代次 TPU 计算/内存/互联规格表、VPU 详细规格、GPU 对比
Sharding 与集合通信参考Sharding 标记法、AllGather/ReduceScatter/AllReduce 开销、并行策略
Transformer 算子性能参考各组件 FLOPs 计算、注意力算术强度、Flash Attention、KV Cache、推理优化
Profiling 深度指南读懂 HLO 和 XLA Op、Tiling/Layout 标记、计算预期耗时、常见性能问题
Varlen Chunk-KDA Padding 分析FLA 与 Pallas varlen Chunk-KDA padding 策略、边界处理和正确性对比

Quick Start

环境搭建

bash
# 克隆仓库
git clone git@github.com:primatrix/pallas-kernel.git
cd pallas-kernel

# 创建虚拟环境并安装依赖
uv venv
source .venv/bin/activate
uv pip install -e ".[dev]"

运行测试

bash
# CPU 模式单元测试
pytest tests/ -v

# TPU 上运行(需要 TPU 环境)
pytest tests/ -v --tpu

运行 Benchmark

bash
python benchmarks/matmul_bench.py --shape 1024,1024,1024

与 onboarding 文档的关系

wiki 中的 Pallas Kernel 编写经验总结TPU 性能优化指南 是面向所有新人的通用入门材料。本项目文档是面向 pallas-kernel 仓库开发者的项目级规范,包含项目特有的约束、流程和标准。两者互补,建议先阅读 onboarding 文档建立基础认知。