性能优化工作拆解
ALModel 17B XL 在 TPU v7x 上的训练性能分析与优化工作全局视图。
理论分析
- 分析框架:性能理论分析框架
- 分析结果:理论分析结果
- 分析工具:ant-pretrain/tools
理论 MFU > 20% 的并行/重计算策略
下表为早期分析版本(XLA reserve=20%),完整最新版本见理论分析结果(XLA reserve=30%)。
| # | TP | DP | PP | EP | FSDP | CP | Remat | MB | GA | W(GB) | O(GB) | G(GB) | FBuf | Act(GB) | Rsv | Tot(GB) | Trn(PF) | Rmt% | Comp | TP | EP | FSDP+ | DP | CP | PP+ | Bub | Opt | Step(s) | MFU% | Bottleneck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 8 | 1 | 1 | 16 | 1 | out_proj | 10 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 4.30 | 32.2 | FSDP |
| 2 | 1 | 4 | 1 | 1 | 32 | 1 | qkv_proj | 10 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 4.40 | 31.5 | FSDP |
| 3 | 1 | 1 | 1 | 1 | 128 | 1 | dot-mlp+ctx | 10 | 4 | 0.5 | 1.0 | 0.5 | 8.3 | 68.4 | 15.7 | 94.5 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 2.71 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 4.41 | 31.4 | FSDP |
| 4 | 1 | 2 | 1 | 1 | 64 | 1 | dot-mlp | 10 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 65.1 | 15.5 | 92.9 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 2.63 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 4.43 | 31.2 | FSDP |
| 5 | 1 | 4 | 1 | 1 | 16 | 2 | out_proj | 20 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 4.66 | 29.7 | FSDP |
| 6 | 1 | 1 | 1 | 1 | 16 | 8 | out_proj | 80 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 4.68 | 29.6 | FSDP |
| 7 | 1 | 1 | 1 | 1 | 64 | 2 | dot-mlp | 20 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 65.1 | 15.5 | 92.9 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 2.63 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 4.71 | 29.4 | FSDP |
| 8 | 1 | 2 | 1 | 1 | 32 | 2 | qkv_proj | 20 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 4.74 | 29.3 | FSDP |
| 9 | 1 | 1 | 1 | 1 | 32 | 4 | qkv_proj | 40 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 4.74 | 29.2 | FSDP |
| 10 | 1 | 2 | 1 | 1 | 16 | 4 | out_proj | 40 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 4.80 | 28.9 | FSDP |
| 11 | 4 | 32 | 1 | 1 | 1 | 1 | save_all | 1 | 160 | 16.2 | 32.5 | 16.2 | 0.0 | 10.2 | 15.0 | 90.1 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.39 | 0.00 | 0.00 | 0.00 | 0.03 | 5.21 | 26.6 | TP AR |
| 12 | 1 | 8 | 1 | 1 | 16 | 1 | dot-mlp+ctx | 8 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 5.26 | 26.3 | FSDP |
| 13 | 4 | 32 | 1 | 1 | 1 | 1 | min+ctx | 2 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.39 | 0.00 | 0.00 | 0.00 | 0.03 | 5.31 | 26.1 | TP AR |
| 14 | 1 | 4 | 1 | 1 | 32 | 1 | dot-mlp+ctx | 8 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 5.40 | 25.7 | FSDP |
| 15 | 1 | 2 | 1 | 1 | 64 | 1 | dot-mlp+ctx | 8 | 5 | 1.0 | 2.0 | 1.0 | 8.3 | 54.7 | 13.4 | 80.4 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.71 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 5.46 | 25.4 | FSDP |
| 16 | 4 | 16 | 1 | 1 | 1 | 2 | save_all | 2 | 160 | 16.2 | 32.5 | 16.2 | 0.0 | 9.2 | 14.8 | 89.0 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.52 | 25.1 | TP AR |
| 17 | 4 | 8 | 1 | 1 | 1 | 4 | save_all | 5 | 128 | 16.2 | 32.5 | 16.2 | 0.0 | 11.0 | 15.2 | 91.1 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.58 | 24.8 | TP AR |
| 18 | 4 | 16 | 1 | 1 | 1 | 2 | min+ctx | 4 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.61 | 24.7 | TP AR |
| 19 | 1 | 4 | 1 | 1 | 16 | 2 | dot-mlp+ctx | 16 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 5.62 | 24.6 | FSDP |
| 20 | 1 | 1 | 1 | 1 | 16 | 8 | dot-mlp+ctx | 64 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 5.65 | 24.5 | FSDP |
| 21 | 4 | 8 | 1 | 1 | 1 | 4 | min+ctx | 8 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.68 | 24.4 | TP AR |
| 22 | 1 | 1 | 1 | 1 | 64 | 2 | dot-mlp+ctx | 16 | 5 | 1.0 | 2.0 | 1.0 | 8.3 | 54.7 | 13.4 | 80.4 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.71 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 5.73 | 24.2 | FSDP |
| 23 | 1 | 2 | 1 | 1 | 32 | 2 | dot-mlp+ctx | 16 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 5.74 | 24.1 | FSDP |
| 24 | 1 | 1 | 1 | 1 | 32 | 4 | dot-mlp+ctx | 32 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 5.75 | 24.1 | FSDP |
| 25 | 1 | 2 | 1 | 1 | 16 | 4 | dot-mlp+ctx | 32 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 5.76 | 24.0 | FSDP |
| 26 | 4 | 4 | 1 | 1 | 1 | 8 | save_all | 10 | 128 | 16.2 | 32.5 | 16.2 | 0.0 | 10.7 | 15.1 | 90.7 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 5.81 | 23.8 | TP AR |
| 27 | 4 | 16 | 1 | 1 | 1 | 2 | dot-mlpwi | 5 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.82 | 23.8 | TP AR |
| 28 | 4 | 8 | 1 | 1 | 1 | 4 | dot-mlpwi | 10 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.88 | 23.6 | TP AR |
| 29 | 4 | 4 | 1 | 1 | 1 | 8 | min+ctx | 16 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 5.91 | 23.4 | TP AR |
| 30 | 4 | 4 | 1 | 1 | 1 | 8 | dot-mlpwi | 20 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 6.11 | 22.7 | TP AR |
| 31 | 8 | 16 | 1 | 1 | 1 | 1 | save_all | 5 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 37.9 | 14.1 | 84.4 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.19 | 0.00 | 0.00 | 0.00 | 0.02 | 6.18 | 22.4 | TP AR |
| 32 | 8 | 16 | 1 | 1 | 1 | 1 | min+ctx | 8 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.19 | 0.00 | 0.00 | 0.00 | 0.02 | 6.28 | 22.1 | TP AR |
| 33 | 2 | 8 | 1 | 1 | 8 | 1 | dot-mlp+ctx | 10 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 6.29 | 22.0 | TP AR |
| 34 | 8 | 8 | 1 | 1 | 1 | 2 | save_all | 10 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 35.5 | 13.6 | 81.6 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 6.48 | 21.4 | TP AR |
| 35 | 2 | 4 | 1 | 1 | 16 | 1 | dot-mlp+ctx | 10 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 6.53 | 21.2 | FSDP |
| 36 | 8 | 8 | 1 | 1 | 1 | 2 | min+ctx | 16 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 6.58 | 21.1 | TP AR |
| 37 | 2 | 2 | 1 | 1 | 32 | 1 | dot-mlpwi | 10 | 8 | 1.0 | 2.0 | 1.0 | 4.1 | 70.8 | 15.8 | 94.7 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.63 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 6.63 | 20.9 | FSDP |
| 38 | 2 | 1 | 1 | 1 | 64 | 1 | dot-mlpwi | 10 | 8 | 0.5 | 1.0 | 0.5 | 4.1 | 70.8 | 15.4 | 92.4 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.69 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 6.65 | 20.8 | FSDP |
| 39 | 2 | 4 | 1 | 1 | 8 | 2 | dot-mlp+ctx | 20 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 6.65 | 20.8 | TP AR |
| 40 | 8 | 4 | 1 | 1 | 1 | 4 | save_all | 20 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 34.4 | 13.4 | 80.2 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 6.66 | 20.8 | TP AR |
| 41 | 2 | 1 | 1 | 1 | 8 | 8 | dot-mlp+ctx | 80 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 6.67 | 20.8 | TP AR |
| 42 | 8 | 4 | 1 | 1 | 1 | 4 | min+ctx | 32 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 6.76 | 20.5 | TP AR |
| 43 | 2 | 2 | 1 | 1 | 8 | 4 | dot-mlp+ctx | 40 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 6.79 | 20.4 | TP AR |
| 44 | 8 | 2 | 1 | 1 | 1 | 8 | save_all | 40 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 33.8 | 13.2 | 79.5 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 6.81 | 20.3 | TP AR |
| 45 | 2 | 2 | 1 | 1 | 16 | 2 | dot-mlp+ctx | 20 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 6.87 | 20.2 | FSDP |
| 46 | 2 | 1 | 1 | 1 | 16 | 4 | dot-mlp+ctx | 40 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 6.87 | 20.2 | FSDP |
| 47 | 2 | 1 | 1 | 1 | 32 | 2 | dot-mlpwi | 20 | 8 | 1.0 | 2.0 | 1.0 | 4.1 | 70.8 | 15.8 | 94.7 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.63 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 6.90 | 20.1 | FSDP |
| 48 | 8 | 2 | 1 | 1 | 1 | 8 | min+ctx | 64 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 6.91 | 20.1 | TP AR |
Profiling
通用配置
- gbs=5120
- v7x 64 chips
GPU vs TPU MFU 对比
GPU MFU 指标(1280 张 H200 一天处理 1133 B token)
| 指标 | 值 |
|---|---|
| 集群峰值 | 1,265.9 PFLOPS |
| tokens/day | 1,133B |
| tokens/sec | 13.11M |
| tokens/GPU/sec | 10,245 |
| Training FLOPs/sample | 39.96 TFLOPs |
| 每 GPU 有效算力 | 100.5 TFLOPS |
| GPU 峰值 (BF16) | 989 TFLOPS |
| MFU | 10.2% |
TPU v7x MFU 指标(64 chips, 128 cores, DP=4, FSDP=32)
| 指标 | 值 |
|---|---|
| 集群峰值 | 147.6 PFLOPS |
| step time | 12s |
| tokens/step | 20,971,520 |
| tokens/sec | 1.75M |
| Training FLOPs/sample | 39.96 TFLOPs |
| 端到端计算量/chip | 3,196.8 TFLOPs |
| 每 chip 有效算力 | 266.4 TFLOPS |
| chip 峰值 (BF16) | 2,307 TFLOPS |
| MFU | 11.5% |
待实现功能
| 功能 | 状态 | 说明 |
|---|---|---|
| Profiling CI | ✅ 已完成 | |
| scan_layers=true | ✅ 已完成 | |
| GLA kernel mbs=4 VMEM OOM | ✅ 已修复 | |
| 理论计算分析工具矫正 | 进行中 | |
| 长序列验证/性能优化 | 进行中 | |
| Grain DataLoader get_next_batch 耗时过长 | ✅ 已修复 | 原 > 600ms |
| PP 适配 al_model hybrid 架构 | 待评估 | |
| Kernel Profiling CI workflow | ✅ 已完成 | |
| GLA/Megablox 算子 Profiling | ✅ 已完成 | 见下方 Kernel Profiling 结果 |
Profiling 分析
端到端并行/重计算策略(gbs=5120, 64 chips)
MFU 由 step time 推导:MFU = model_FLOPs / (step_time × num_devices × peak_FLOPS),基线 run68 校准为 11.89%。
| # | CI Run | DP | FSDP | EP | Remat | pdb×ga | Step (s) | MFU | 说明 |
|---|---|---|---|---|---|---|---|---|---|
| 1 | ci-prof-run68 | 4 | 32 | 1 | save_out_proj | 10×4 | 11.775 | 11.89% | Canonical 最优(复现: run92, run93) |
| 2 | ci-prof-run115 | 4 | 32 | 1 | save_out_proj | 10×4 | 11.801 | 11.86% | +named_scope,性能无影响 |
| 3 | ci-prof-run100 | 4 | 32 | 1 | save_dot_ctx_ex_mlp | 8×5 | 12.124 | 11.55% | |
| 4 | ci-prof-run95 | 2 | 64 | 1 | save_out_proj | 10×4 | 12.377 | 11.31% | NaN@step3 |
| 5 | ci-prof-run99 | 8 | 16 | 1 | save_qkv_proj | 8×5 | 12.574 | 11.13% | |
| 6 | ci-prof-run94 | 1 | 128 | 1 | save_qkv_proj | 10×4 | 13.013 | 10.76% | |
| 7 | ci-prof-run101 | 1 | 128 | 1 | save_dot_ctx_ex_mlp | 8×5 | 13.163 | 10.64% | |
| 8 | ci-prof-run102 | 2 | 64 | 1 | save_dot_ctx_ex_mlp | 8×5 | 13.839 | 10.12% | |
| 9 | ci-prof-run79 | 1 | 128 | 1 | minimal | 2×20 | 16.873 | 8.30% | 旧版本代码 |
| 10 | ci-prof-run90 | 1 | 128 | 1 | minimal | 2×20 | 31.447 | 4.45% | 复现: run88 (32.15s) |
| 11 | ci-prof-run87 | 1 | 128 | 1 | minimal_with_ctx | 2×20 | 32.061 | 4.37% | |
| 12 | ci-prof-run111 | 1 | 16 | 8 | save_dot_ctx_ex_mlp | 8×5 | 61.461 | 2.28% | |
| 13 | ci-prof-run112 | 1 | 16 | 8 | save_dot_except_mlp | 10×4 | 61.820 | 2.26% | |
| 14 | ci-prof-run110 | 16 | 1 | 8 | minimal_with_ctx | 1×40 | 91.643 | 1.53% | NaN@step1 |
Step time 取稳态 steps (3-9) 中位数。CP=2/4/8 共 9 个配置因与 packing 不兼容而失败,未列入。
结论
- DP=4/FSDP=32/save_out_proj 确认最优,4 次独立运行一致(run68/92/93/115, ~11.8s)
- Remat 排名:save_out_proj > save_dot_ctx_ex_mlp > save_qkv_proj
- EP=1 有效配置间差异 <18%,MFU 天花板由内核效率决定
Canonical 基线 Step 时间分解(run115, XPlane + named_scope 分析)
数据来源:ci-prof-run115 XPlane (TPU:0),与 run68 同配置(DP=4, FSDP=32, save_out_proj, pdb=10×ga=4),额外启用
named_scope标注,可将算子归因到模型组件。Step time 11.68s,性能与 run68 (11.78s) 一致。named_scope 使 GLA 可独立归因。下表中 GLA 从 MatMul / Elementwise / Other 中提取为独立行;分析脚本:scripts/analyze_run115_final_v2.py。
| 类别 | 耗时 (ms) | 占比 | Op 数 | 平均 (ms) | 前向 / 反向 | 说明 |
|---|---|---|---|---|---|---|
| GLA(Pallas+Dense+Elemwise) | 2,868 | 24.6% | 136,836 | 0.021 | 729 / 2,139 | 最大单一热点;Pallas kernel 570ms · qkv+out matmul 1,062ms · elemwise 546ms · data format 417ms · 其余 273ms;bwd 100% 为 remat 重算 |
| GMM(MoE 前向+反向 dlhs) | 1,909 | 16.3% | 708 | 2.696 | 642 / 1,268 | 20 MoE 层 × 3 ops × ga=4;bwd 97% remat |
| Async Fusion Stall | 1,356 | 11.6% | 1,504 | 0.901 | 477 / 879 | 通信-计算 overlap 等待(XLA 自动管理) |
| MatMul(Dense,非 GLA) | 1,143 | 9.8% | 1,624 | 0.704 | 322 / 820 | MoE expert 41% · output proj 29% · logits 29%;bwd 42% remat |
| Elementwise(非 GLA) | 924 | 7.9% | 27,368 | 0.034 | 315 / 597 | 激活、norm、residual(GLA 部分已提取) |
| TGMM(MoE 反向 drhs) | 668 | 5.7% | 240 | 2.783 | — / 668 | 反向权重梯度(XLA 延迟调度);95% remat scope |
| Custom Fusion(MoE dispatch) | 668 | 5.7% | 5,004 | 0.133 | 327 / 186 | token permute/unpermute |
| Attention(MLA SplashAttn) | 628 | 5.4% | 76 | 8.267 | 149 / 479 | 仅 4 层 MLA;fwd 149ms · dKV 188ms · dQ 172ms;85% remat |
| All-Gather 通信 | 404 | 3.5% | 1,108 | 0.365 | 29 / 375 | FSDP weight gather(GLA all-gather 38ms 已提取) |
| All-Reduce 通信 | 393 | 3.4% | 392 | 1.001 | — / 393 | 梯度同步 |
| Routing/Sort | 363 | 3.1% | 948 | 0.383 | 206 / 157 | top-k 路由 |
| Data Formatting(非 GLA) | 158 | 1.3% | 5,900 | 0.027 | 26 / 48 | XLA layout 转换(GLA 部分 417ms 已提取) |
| Other(AllToAll、ReduceScatter 等) | 198 | 1.7% | 2,301 | 0.086 | — | AllToAll 87ms · ReduceScatter 55ms · 杂项 56ms |
ConcatBitcast (DMA offload stall):2,921ms 的 DMA 等待时间被 XLA 与上述计算完全 overlap(去除后剩余算子总耗时 11,680ms ≈ step time),不占用额外 wall-clock 时间。
四类归因:
| 归因 | 耗时 (ms) | 占比 | 含义 |
|---|---|---|---|
| GLA 计算 | 2,868 | 24.6% | GLA Pallas kernel + 投影 + elemwise + 格式转换 |
| MoE 专有计算 | 3,608 | 30.9% | GMM 1,909 + TGMM 668 + dispatch 668 + routing 363 |
| 高效计算(非 GLA) | 1,771 | 15.2% | MatMul 1,143 + Attention 628 |
| 非计算开销 | 3,434 | 29.4% | comm 995 + stall 1,356 + elemwise 924 + format 158 |
Remat 开销:save_out_proj 策略下,反向传播中 89% 的算子时间用于前向重算(remat)。GLA 反向 100% 为重算,GMM 反向 97% 为重算,Attention 反向 85% 为重算。总 remat 耗时约 7,285ms,占 step time 的 62.4%。
named_scope 组件归因:MoE 层 28.2% · GLA 层 24.6% · MTP block 4.1% · MLA attention 2.2% · Output projection 2.5% · 其余 38.5%(Async stall/comm/elemwise 等无明确组件归属)
Kernel Profiling 结果
GLA(前向+反向,每步 20.4ms,ci-op-prof-run17/18, pdb=10)
Pallas chunk_simple_gla 内核,含 chunk_fwd_h + chunk_bwd_dh + chunk_simple_gla_bwd_o_pl。
| 算子 | 耗时 (ms) | 占比 | 说明 |
|---|---|---|---|
| transpose_jvp(反向) | 9.50 | 46.6% | Pallas 反向内核,含 fori_loop 控制流 |
| while.646(前向 scan) | 1.93 | 9.5% | Pallas chunk scan 循环 |
| convolution fusion(matmul) | 3.82 | 18.7% | AI=76.8,远低于 ridge=313 |
| data formatting(copy) | 2.59 | 12.7% | HBM→VMEM layout 转换 |
硬件单元利用率(GLA 反向,9.5ms 窗口):
| 硬件单元 | 平均 % | 最高 % | 说明 |
|---|---|---|---|
| Scalar ALU | 34.0% | 54.5% | 主导 — Pallas kernel 内 fori_loop/cond 控制流 |
| MXU | 11.6% | 25.2% | 矩阵乘法单元利用率低 |
| Vector ALU | 7.6% | 16.5% | |
| Vector Load | 8.8% | 12.1% | |
| Vector Store | 3.3% | 7.1% |
瓶颈分析:Scalar ALU 是 MXU 的 3×,说明 Pallas backward kernel 的
fori_loop/lax.cond控制流产生大量标量运算,MXU 大部分时间空闲。反向耗时是前向的 4.9×。
Megablox GMM/TGMM(前向+反向,每步 11.8ms,ci-op-prof-run17, pdb=10)
| 算子 | 耗时 (ms) | 占比 | 说明 |
|---|---|---|---|
| transpose_jvp_tgmm(×2) | 5.54 | 46.9% | TGMM 反向 |
| transpose_jvp_gmm(×2) | 5.14 | 43.6% | GMM 反向 |
| broadcast(初始化) | 1.00 | 8.5% | 输出缓冲区清零 |
| 前向 GMM | 不可见 | — | 被 XLA 内联 |
硬件单元利用率(GMM/TGMM 反向):
| 硬件单元 | TGMM 反向 | GMM 反向 | 说明 |
|---|---|---|---|
| MXU | mean 33.4%, max 84.1% | mean 39.9%, max 85.5% | |
| Scalar ALU | mean 39.5%, max 87.8% | mean 36.5%, max 85.2% | 与 MXU 接近 |
| Vector Store | mean 25.4%, max 96.4% | mean 21.4%, max 97.0% | 写入带宽饱和 |
| Vector Load | mean 12.5%, max 83.9% | mean 10.2%, max 32.6% |
瓶颈分析:MXU (33-40%) 和 Scalar ALU (36-40%) 接近,说明内核在做矩阵运算但效率受限。Vector Store max 97% 表明 VMEM→HBM 写回是关键瓶颈。Per-expert AI=310 接近 ridge point=313,问题主要在内核 tiling/pipeline 实现层面。
GMM Tiling 实验
| Tiling | pdb=2 | pdb=10 | vs 默认 |
|---|---|---|---|
| 512,1024,1024(默认) | 4.2ms | 13ms | — |
| 128,256,128 | 41ms | — | +876% |
| 512,512,512 | 5ms | 15ms | +15-19% |
默认 tiling 最优,小 tile 带来 64× 更多 grid 迭代,overhead 远超计算节省。
Roofline 定位
| 算子 | AI (FLOPs/byte) | vs Ridge (313) | Roofline 上限 | 实际效率 | Bound |
|---|---|---|---|---|---|
| Dense qkv_proj | 1,296 | 4.1× | ~100% | ~100% | Compute |
| GMM wi_0 (pdb=10) | 178 | 0.57× | 57% | ~11% | Memory |
| TGMM wi_0 (pdb=10) | ~150 | 0.48× | 48% | ~7.5% | Memory |
| GLA intra_qk | 25.6 | 0.08× | 8% | ~0.4% | Memory |
MoE 256 experts × top-8 将大矩阵拆成 256 个小 matmul,每个 expert 仅 ~50 tokens(pdb=10),AI 低于 ridge point → HBM 带宽受限。
优化项与预估提升空间
框架层优化
| 优化项 | 类别 | 预估节省 | 原理 | 优先级 | 状态 |
|---|---|---|---|---|---|
| Data Loader 异步预取 | DataLoader | — | host→device 数据传输异步化 | — | ✅ 已实现 |
| Data Formatting 消除 | XLA Layout | 10-20ms/step | 574ms 中部分 copy 可通过 layout 提示避免 | P2 | 待分析 |
| XLA SparseCore Flags 安全子集 | XLA 编译器 | 10-45ms/step | 17 flags 中隔离不导致 OOM 的子集 | P1 | 整包 OOM,需逐个测试 |
并行策略优化
| 优化项 | 类别 | 预估节省 | 原理 | 优先级 | 状态 |
|---|---|---|---|---|---|
| Pipeline Parallelism (PP) | 并行 | 待评估 | 异构层(GLA/MLA)天然不适合 PP 分组 | P2 | 需 ALPipelineStage 支持 |
| scan_layers | 并行 | 减少编译时间 | 支持 PP 的前提;异构层需 ALPipelineStage | — | ✅ 已实现 |
算子级优化
| 优化项 | 类别 | 预估节省 | 原理 | 优先级 | 状态 |
|---|---|---|---|---|---|
| GLA 反向 kernel 优化 | GLA | 减少 6-8ms/层 | 反向 9.5ms vs 前向 1.9ms(4.9×),Scalar ALU 34% 主导 | P0 | 被 JAX dynamic_slice 阻塞 |
| GLA | — | ❌ 已放弃:会影响精度 | |||
| MoE Sort+Gather 融合 | MoE | 20-40ms/step | routing/sort=363ms,重复 HBM 访问 | P1 | 待实现 |
| GMM/TGMM Vector Store 优化 | MoE | 待评估 | Vector Store max 97% 饱和,需改善写回 pipeline | P2 | 需内核层面优化 |
Profiling 数据:gs://ant-pretrain/pretrain/profiling/ 和 gs://ant-pretrain/pretrain/operator-profiling/