pallas-kernel
pallas-kernel(PyPI 包名 tops)是 primatrix 团队的 TPU Pallas kernel 集合,为 MaxText、sglang-jax 等上层项目提供高性能、可测试、可复用的 Pallas kernel 实现。
核心目标
- 高性能:每个 kernel 经过 Roofline 分析和 profiling 验证,逼近硬件理论上限
- 可测试:每个 kernel 提供与 NumPy/PyTorch 参考实现的数值精度对比测试
- 可复用:统一的 API 设计,支持通过
pip install作为依赖引入
技术栈
| 组件 | 版本 |
|---|---|
| Python | ≥ 3.12 |
| JAX | ≥ 0.7.0 |
| Pallas API | jax.experimental.pallas.tpu |
| 构建工具 | uv |
目标 TPU 硬件
项目主要面向以下 TPU 代次,kernel 文档中需标注各代次的兼容性:
| TPU 代次 | 说明 |
|---|---|
| v4 | 基础支持 |
| v5e | 推理优化 |
| v6e | 当前开发主力 |
| v7x | 高端训练 |
文档导航
项目规范
| 文档 | 说明 |
|---|---|
| 硬件约束与 API 限制 | TPU 硬件架构、Pallas API 硬性约束、内存对齐规则、性能陷阱 |
| 代码规范 | 项目结构、命名约定、代码风格、kernel 实现规范 |
| CI/CD 流水线 | 持续集成阶段、TPU 集成测试、pre-commit 钩子、版本发布 |
| 开发流程 | 分支策略、kernel 开发生命周期、PR 规范、测试要求 |
| Benchmark 规范 | Roofline 分析方法、benchmark 编写规范、profiling 工具链 |
| Kernel 文档模板 | 新增 kernel 时复制使用的标准化文档模板 |
性能工程参考
以下文档基于 How to Scale Your Model 系列整理,提供 kernel 开发所需的性能分析背景知识。
| 文档 | 说明 |
|---|---|
| Roofline 分析深度指南 | 算术强度公式、临界 batch size、tiling 强度、多芯片通信 Roofline |
| TPU 硬件规格参考 | 跨代次 TPU 计算/内存/互联规格表、VPU 详细规格、GPU 对比 |
| Sharding 与集合通信参考 | Sharding 标记法、AllGather/ReduceScatter/AllReduce 开销、并行策略 |
| Transformer 算子性能参考 | 各组件 FLOPs 计算、注意力算术强度、Flash Attention、KV Cache、推理优化 |
| Profiling 深度指南 | 读懂 HLO 和 XLA Op、Tiling/Layout 标记、计算预期耗时、常见性能问题 |
| Varlen Chunk-KDA Padding 分析 | FLA 与 Pallas varlen Chunk-KDA padding 策略、边界处理和正确性对比 |
Quick Start
环境搭建
bash
# 克隆仓库
git clone git@github.com:primatrix/pallas-kernel.git
cd pallas-kernel
# 创建虚拟环境并安装依赖
uv venv
source .venv/bin/activate
uv pip install -e ".[dev]"运行测试
bash
# CPU 模式单元测试
pytest tests/ -v
# TPU 上运行(需要 TPU 环境)
pytest tests/ -v --tpu运行 Benchmark
bash
python benchmarks/matmul_bench.py --shape 1024,1024,1024与 onboarding 文档的关系
wiki 中的 Pallas Kernel 编写经验总结 和 TPU 性能优化指南 是面向所有新人的通用入门材料。本项目文档是面向 pallas-kernel 仓库开发者的项目级规范,包含项目特有的约束、流程和标准。两者互补,建议先阅读 onboarding 文档建立基础认知。