
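Performance Optimization Work Breakdown

A global view of the training performance analysis and optimization work for ALModel 17B XL on TPU v7x.

Theoretical Analysis

Parallelism/Remat Strategies with Theoretical MFU > 20%

The table below comes from an early version of the analysis (XLA reserve = 20%). The complete, up-to-date version is in the theoretical analysis results (XLA reserve = 30%).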

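# | TP | DP | PP | EP | FSDP | CP | Remat | MB | GA | W(GB) | O(GB) | G(GB) | FBuf(GB) | Act(GB) | Rsv(GB) | Tot(GB) | Trn(PF) | Rmt% | Comp(s) | TP(s) | EP(s) | FSDP(s) | DP(s) | CP(s) | PP(s) | Bub(s) | Opt(s) | Step(s) | MFU% | Bottleneck
1 | 1 | 8 | 1 | 1 | 16 | 1 | out_proj | 10 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 4.30 | 32.2 | FSDP
2 | 1 | 4 | 1 | 1 | 32 | 1 | qkv_proj | 10 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 4.40 | 31.5 | FSDP
3 | 1 | 1 | 1 | 1 | 128 | 1 | dot-mlp+ctx | 10 | 4 | 0.5 | 1.0 | 0.5 | 8.3 | 68.4 | 15.7 | 94.5 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 2.71 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 4.41 | 31.4 | FSDP
4 | 1 | 2 | 1 | 1 | 64 | 1 | dot-mlp | 10 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 65.1 | 15.5 | 92.9 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 2.63 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 4.43 | 31.2 | FSDP
5 | 1 | 4 | 1 | 1 | 16 | 2 | out_proj | 20 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 4.66 | 29.7 | FSDP
6 | 1 | 1 | 1 | 1 | 16 | 8 | out_proj | 80 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 4.68 | 29.6 | FSDP
7 | 1 | 1 | 1 | 1 | 64 | 2 | dot-mlp | 20 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 65.1 | 15.5 | 92.9 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 2.63 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 4.71 | 29.4 | FSDP
8 | 1 | 2 | 1 | 1 | 32 | 2 | qkv_proj | 20 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 4.74 | 29.3 | FSDP
9 | 1 | 1 | 1 | 1 | 32 | 4 | qkv_proj | 40 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 61.8 | 15.6 | 93.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 2.54 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 4.74 | 29.2 | FSDP
10 | 1 | 2 | 1 | 1 | 16 | 4 | out_proj | 40 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 54.5 | 15.7 | 94.4 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 2.38 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 4.80 | 28.9 | FSDP
11 | 4 | 32 | 1 | 1 | 1 | 1 | save_all | 1 | 160 | 16.2 | 32.5 | 16.2 | 0.0 | 10.2 | 15.0 | 90.1 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.39 | 0.00 | 0.00 | 0.00 | 0.03 | 5.21 | 26.6 | TP AR
12 | 1 | 8 | 1 | 1 | 16 | 1 | dot-mlp+ctx | 8 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 5.26 | 26.3 | FSDP
13 | 4 | 32 | 1 | 1 | 1 | 1 | min+ctx | 2 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.39 | 0.00 | 0.00 | 0.00 | 0.03 | 5.31 | 26.1 | TP AR
14 | 1 | 4 | 1 | 1 | 32 | 1 | dot-mlp+ctx | 8 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 5.40 | 25.7 | FSDP
15 | 1 | 2 | 1 | 1 | 64 | 1 | dot-mlp+ctx | 8 | 5 | 1.0 | 2.0 | 1.0 | 8.3 | 54.7 | 13.4 | 80.4 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.71 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 5.46 | 25.4 | FSDP
16 | 4 | 16 | 1 | 1 | 1 | 2 | save_all | 2 | 160 | 16.2 | 32.5 | 16.2 | 0.0 | 9.2 | 14.8 | 89.0 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.52 | 25.1 | TP AR
17 | 4 | 8 | 1 | 1 | 1 | 4 | save_all | 5 | 128 | 16.2 | 32.5 | 16.2 | 0.0 | 11.0 | 15.2 | 91.1 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.58 | 24.8 | TP AR
18 | 4 | 16 | 1 | 1 | 1 | 2 | min+ctx | 4 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.61 | 24.7 | TP AR
19 | 1 | 4 | 1 | 1 | 16 | 2 | dot-mlp+ctx | 16 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 5.62 | 24.6 | FSDP
20 | 1 | 1 | 1 | 1 | 16 | 8 | dot-mlp+ctx | 64 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 5.65 | 24.5 | FSDP
21 | 4 | 8 | 1 | 1 | 1 | 4 | min+ctx | 8 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.68 | 24.4 | TP AR
22 | 1 | 1 | 1 | 1 | 64 | 2 | dot-mlp+ctx | 16 | 5 | 1.0 | 2.0 | 1.0 | 8.3 | 54.7 | 13.4 | 80.4 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.71 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 5.73 | 24.2 | FSDP
23 | 1 | 2 | 1 | 1 | 32 | 2 | dot-mlp+ctx | 16 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 5.74 | 24.1 | FSDP
24 | 1 | 1 | 1 | 1 | 32 | 4 | dot-mlp+ctx | 32 | 5 | 2.0 | 4.1 | 2.0 | 8.1 | 54.7 | 14.2 | 85.1 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.62 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 5.75 | 24.1 | FSDP
25 | 1 | 2 | 1 | 1 | 16 | 4 | dot-mlp+ctx | 32 | 5 | 4.1 | 8.1 | 4.1 | 7.9 | 54.7 | 15.8 | 94.6 | 204.6 | 23 | 1.71 | 0.00 | 0.00 | 3.46 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 5.76 | 24.0 | FSDP
26 | 4 | 4 | 1 | 1 | 1 | 8 | save_all | 10 | 128 | 16.2 | 32.5 | 16.2 | 0.0 | 10.7 | 15.1 | 90.7 | 204.6 | 0 | 1.39 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 5.81 | 23.8 | TP AR
27 | 4 | 16 | 1 | 1 | 1 | 2 | dot-mlpwi | 5 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.38 | 0.32 | 0.00 | 0.00 | 0.03 | 5.82 | 23.8 | TP AR
28 | 4 | 8 | 1 | 1 | 1 | 4 | dot-mlpwi | 10 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.35 | 0.41 | 0.00 | 0.00 | 0.03 | 5.88 | 23.6 | TP AR
29 | 4 | 4 | 1 | 1 | 1 | 8 | min+ctx | 16 | 80 | 16.2 | 32.5 | 16.2 | 0.0 | 14.0 | 15.8 | 94.8 | 204.6 | 7 | 1.48 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 5.91 | 23.4 | TP AR
30 | 4 | 4 | 1 | 1 | 1 | 8 | dot-mlpwi | 20 | 64 | 16.2 | 32.5 | 16.2 | 0.0 | 14.9 | 16.0 | 95.8 | 204.6 | 22 | 1.68 | 3.40 | 0.00 | 0.00 | 0.53 | 0.47 | 0.00 | 0.00 | 0.03 | 6.11 | 22.7 | TP AR
31 | 8 | 16 | 1 | 1 | 1 | 1 | save_all | 5 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 37.9 | 14.1 | 84.4 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.19 | 0.00 | 0.00 | 0.00 | 0.02 | 6.18 | 22.4 | TP AR
32 | 8 | 16 | 1 | 1 | 1 | 1 | min+ctx | 8 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.19 | 0.00 | 0.00 | 0.00 | 0.02 | 6.28 | 22.1 | TP AR
33 | 2 | 8 | 1 | 1 | 8 | 1 | dot-mlp+ctx | 10 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 6.29 | 22.0 | TP AR
34 | 8 | 8 | 1 | 1 | 1 | 2 | save_all | 10 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 35.5 | 13.6 | 81.6 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 6.48 | 21.4 | TP AR
35 | 2 | 4 | 1 | 1 | 16 | 1 | dot-mlp+ctx | 10 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.07 | 0.00 | 0.00 | 0.00 | 0.00 | 6.53 | 21.2 | FSDP
36 | 8 | 8 | 1 | 1 | 1 | 2 | min+ctx | 16 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 6.58 | 21.1 | TP AR
37 | 2 | 2 | 1 | 1 | 32 | 1 | dot-mlpwi | 10 | 8 | 1.0 | 2.0 | 1.0 | 4.1 | 70.8 | 15.8 | 94.7 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.63 | 0.04 | 0.00 | 0.00 | 0.00 | 0.00 | 6.63 | 20.9 | FSDP
38 | 2 | 1 | 1 | 1 | 64 | 1 | dot-mlpwi | 10 | 8 | 0.5 | 1.0 | 0.5 | 4.1 | 70.8 | 15.4 | 92.4 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.69 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 6.65 | 20.8 | FSDP
39 | 2 | 4 | 1 | 1 | 8 | 2 | dot-mlp+ctx | 20 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.13 | 0.32 | 0.00 | 0.00 | 0.01 | 6.65 | 20.8 | TP AR
40 | 8 | 4 | 1 | 1 | 1 | 4 | save_all | 20 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 34.4 | 13.4 | 80.2 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 6.66 | 20.8 | TP AR
41 | 2 | 1 | 1 | 1 | 8 | 8 | dot-mlp+ctx | 80 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.00 | 0.47 | 0.00 | 0.00 | 0.01 | 6.67 | 20.8 | TP AR
42 | 8 | 4 | 1 | 1 | 1 | 4 | min+ctx | 32 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 6.76 | 20.5 | TP AR
43 | 2 | 2 | 1 | 1 | 8 | 4 | dot-mlp+ctx | 40 | 8 | 4.1 | 8.1 | 4.1 | 3.7 | 58.3 | 15.6 | 93.9 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.22 | 0.18 | 0.41 | 0.00 | 0.00 | 0.01 | 6.79 | 20.4 | TP AR
44 | 8 | 2 | 1 | 1 | 1 | 8 | save_all | 40 | 64 | 8.1 | 16.2 | 8.1 | 0.0 | 33.8 | 13.2 | 79.5 | 204.6 | 0 | 1.39 | 4.59 | 0.00 | 0.00 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 6.81 | 20.3 | TP AR
45 | 2 | 2 | 1 | 1 | 16 | 2 | dot-mlp+ctx | 20 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.09 | 0.32 | 0.00 | 0.00 | 0.00 | 6.87 | 20.2 | FSDP
46 | 2 | 1 | 1 | 1 | 16 | 4 | dot-mlp+ctx | 40 | 8 | 2.0 | 4.1 | 2.0 | 3.9 | 58.3 | 14.1 | 84.5 | 204.6 | 23 | 1.71 | 2.27 | 0.00 | 2.48 | 0.00 | 0.41 | 0.00 | 0.00 | 0.00 | 6.87 | 20.2 | FSDP
47 | 2 | 1 | 1 | 1 | 32 | 2 | dot-mlpwi | 20 | 8 | 1.0 | 2.0 | 1.0 | 4.1 | 70.8 | 15.8 | 94.7 | 204.6 | 22 | 1.68 | 2.27 | 0.00 | 2.63 | 0.00 | 0.32 | 0.00 | 0.00 | 0.00 | 6.90 | 20.1 | FSDP
48 | 8 | 2 | 1 | 1 | 1 | 8 | min+ctx | 64 | 40 | 8.1 | 16.2 | 8.1 | 0.0 | 47.4 | 16.0 | 95.8 | 204.6 | 7 | 1.48 | 4.59 | 0.00 | 0.00 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 6.91 | 20.1 | TP AR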
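The derived columns in this table appear to follow simple arithmetic. Tot is the sum of the memory columns, and Step(s) is the sum of the compute and communication columns. MFU% is Trn(PF) / (Step × cluster peak), and Trn(PF) = gbs × 39.96 TFLOPs/sample. The Python sketch below re-derives row 1 under those assumptions. The 147.6 PFLOPS cluster peak (64 chips × 2,307 TFLOPS) and the 39.96 TFLOPs/sample figure come from the MFU tables in the Profiling section. The batch rule gbs = MB × GA × DP × FSDP is inferred from the rows, not taken from the analysis tool itself.

```python
# Sketch: re-derive the computed columns of row 1 of the theoretical table.
# Assumed relationships (inferred from the rows, not from the analysis tool):
#   Tot  = W + O + G + FBuf + Act + Rsv                      (GB per device)
#   Step = Comp + TP + EP + FSDP + DP + CP + PP + Bub + Opt  (seconds)
#   MFU  = Trn / (Step * cluster peak)
#   gbs  = MB * GA * DP * FSDP   (TP and CP do not split the batch)

GBS = 5120
FLOPS_PER_SAMPLE_TF = 39.96                 # training TFLOPs per sample
CLUSTER_PEAK_PF = 64 * 2307 / 1000          # 64 chips x 2,307 TFLOPS ~ 147.6 PFLOPS

row1 = {
    "DP": 8, "FSDP": 16, "MB": 10, "GA": 4,
    "mem_gb": {"W": 4.1, "O": 8.1, "G": 4.1, "FBuf": 7.9, "Act": 54.5, "Rsv": 15.7},
    "time_s": {"Comp": 1.82, "TP": 0.00, "EP": 0.00, "FSDP": 2.38, "DP": 0.09,
               "CP": 0.00, "PP": 0.00, "Bub": 0.00, "Opt": 0.01},
}

trn_pf = GBS * FLOPS_PER_SAMPLE_TF / 1000   # training FLOPs per step, ~204.6 PF
tot = sum(row1["mem_gb"].values())
step = sum(row1["time_s"].values())
mfu = trn_pf / (step * CLUSTER_PEAK_PF)
gbs = row1["MB"] * row1["GA"] * row1["DP"] * row1["FSDP"]

print(f"Trn={trn_pf:.1f} PF  Tot={tot:.1f} GB  Step={step:.2f} s  MFU={mfu:.1%}  gbs={gbs}")
```

The same relations hold for other rows to within rounding. For example, row 11 gives 1.39 + 3.40 + 0.39 + 0.03 = 5.21 s.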

Profiling

General Configuration

  • gbs=5120
  • 64 v7x chips

GPU vs. TPU MFU Comparison

GPU MFU metrics (1,280 H200 GPUs processing 1,133B tokens per day)

Metric | Value
Cluster peak | 1,265.9 PFLOPS
tokens/day | 1,133B
tokens/sec | 13.11M
tokens/GPU/sec | 10,245
Training FLOPs/sample | 39.96 TFLOPs
Effective throughput per GPU | 100.5 TFLOPS
GPU peak (BF16) | 989 TFLOPS
MFU | 10.2%

TPU v7x MFU metrics (64 chips, 128 cores, DP=4, FSDP=32)

Metric | Value
Cluster peak | 147.6 PFLOPS
Step time | 12 s
tokens/step | 20,971,520
tokens/sec | 1.75M
Training FLOPs/sample | 39.96 TFLOPs
End-to-end compute per chip (per step) | 3,196.8 TFLOPs
Effective throughput per chip | 266.4 TFLOPS
Chip peak (BF16) | 2,307 TFLOPS
MFU | 11.5%
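The TPU figures can be reproduced from gbs, step time, and per-sample FLOPs. In this Python sketch, the sequence length of 4,096 is derived (20,971,520 tokens/step ÷ gbs of 5,120) rather than stated in the source. All other inputs are the table's own values.

```python
# Reproduce the TPU v7x MFU table from its inputs.
GBS = 5120                          # global batch size (sequences per step)
TOKENS_PER_STEP = 20_971_520
SEQ_LEN = TOKENS_PER_STEP // GBS    # derived, not stated in the source
STEP_TIME_S = 12.0
FLOPS_PER_SAMPLE_TF = 39.96         # training TFLOPs per sample
NUM_CHIPS = 64
CHIP_PEAK_TF = 2307                 # BF16 peak per chip, TFLOPS

tokens_per_sec = TOKENS_PER_STEP / STEP_TIME_S
work_per_chip_tf = GBS * FLOPS_PER_SAMPLE_TF / NUM_CHIPS   # TFLOPs per chip per step
effective_tflops = work_per_chip_tf / STEP_TIME_S
mfu = effective_tflops / CHIP_PEAK_TF
cluster_peak_pf = NUM_CHIPS * CHIP_PEAK_TF / 1000

print(f"seq_len={SEQ_LEN}  tokens/sec={tokens_per_sec / 1e6:.2f}M")
print(f"work/chip={work_per_chip_tf:.1f} TF  effective={effective_tflops:.1f} TFLOPS")
print(f"MFU={mfu:.1%}  cluster peak={cluster_peak_pf:.1f} PFLOPS")
```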

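Planned Features

Feature | Status | Notes
Profiling CI | ✅ Done |
scan_layers=true | ✅ Done |
GLA kernel mbs=4 VMEM OOM | ✅ Fixed |
Theoretical analysis tool calibration | In progress |
Long-sequence validation / performance optimization | In progress |
Grain DataLoader get_next_batch taking too long | ✅ Fixed | Previously > 600 ms
PP support for the al_model hybrid architecture | To be evaluated |
Kernel Profiling CI workflow | ✅ Done |
GLA/Megablox operator profiling | ✅ Done | See the Kernel Profiling results below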

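Profiling Analysis

End-to-End Parallelism/Remat Strategies (gbs=5120, 64 chips)

MFU is derived from step time: MFU = model_FLOPs / (step_time × num_devices × peak_FLOPS), calibrated so that the run68 baseline gives 11.89%.

# | CI Run | DP | FSDP | EP | Remat | pdb×ga | Step (s) | MFU | Notes
1 | ci-prof-run68 | 4 | 32 | 1 | save_out_proj | 10×4 | 11.775 | 11.89% | Canonical best (reproduced in run92, run93)
2 | ci-prof-run115 | 4 | 32 | 1 | save_out_proj | 10×4 | 11.801 | 11.86% | + named_scope; no performance impact
3 | ci-prof-run100 | 4 | 32 | 1 | save_dot_ctx_ex_mlp | 8×5 | 12.124 | 11.55% |
4 | ci-prof-run95 | 2 | 64 | 1 | save_out_proj | 10×4 | 12.377 | 11.31% | NaN@step3
5 | ci-prof-run99 | 8 | 16 | 1 | save_qkv_proj | 8×5 | 12.574 | 11.13% |
6 | ci-prof-run94 | 1 | 128 | 1 | save_qkv_proj | 10×4 | 13.013 | 10.76% |
7 | ci-prof-run101 | 1 | 128 | 1 | save_dot_ctx_ex_mlp | 8×5 | 13.163 | 10.64% |
8 | ci-prof-run102 | 2 | 64 | 1 | save_dot_ctx_ex_mlp | 8×5 | 13.839 | 10.12% |
9 | ci-prof-run79 | 1 | 128 | 1 | minimal | 2×20 | 16.873 | 8.30% | Old code version
10 | ci-prof-run90 | 1 | 128 | 1 | minimal | 2×20 | 31.447 | 4.45% | Reproduced in run88 (32.15 s)
11 | ci-prof-run87 | 1 | 128 | 1 | minimal_with_ctx | 2×20 | 32.061 | 4.37% |
12 | ci-prof-run111 | 1 | 16 | 8 | save_dot_ctx_ex_mlp | 8×5 | 61.461 | 2.28% |
13 | ci-prof-run112 | 1 | 16 | 8 | save_dot_except_mlp | 10×4 | 61.820 | 2.26% |
14 | ci-prof-run110 | 16 | 1 | 8 | minimal_with_ctx | 1×40 | 91.643 | 1.53% | NaN@step1

Step time is the median over steady-state steps (3-9). Nine configurations using CP=2/4/8 failed because CP is incompatible with packing, so they are not listed.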
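Model FLOPs, device count, and peak FLOPS are the same for every run in this table. The formula therefore reduces to MFU = C / step_time for a single constant C, measured in seconds of ideal compute per step. The Python sketch below calibrates C from run68 and checks it against other rows. The steady-state helper assumes `step_times_s[i]` holds the time of step i, so steps 3-9 are the slice `[3:10]`.

```python
import statistics

# C = model_FLOPs / (num_devices * peak_FLOPS), in seconds; calibrated from run68.
C_SECONDS = 0.1189 * 11.775   # ~1.400 s of ideal compute per step

def mfu_from_step_time(step_time_s: float) -> float:
    """MFU = model_FLOPs / (step_time * num_devices * peak_FLOPS) = C / step_time."""
    return C_SECONDS / step_time_s

def steady_state_step_time(step_times_s: list[float]) -> float:
    """Median over steady-state steps 3-9 (step_times_s[i] is the time of step i)."""
    return statistics.median(step_times_s[3:10])

# (step time, reported MFU) for a few rows of the table above.
reported = {"run115": (11.801, 0.1186), "run100": (12.124, 0.1155),
            "run90": (31.447, 0.0445), "run110": (91.643, 0.0153)}
for run, (step, mfu) in reported.items():
    print(f"{run}: predicted {mfu_from_step_time(step):.2%}, reported {mfu:.2%}")
```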

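Conclusions

  1. DP=4/FSDP=32/save_out_proj is confirmed as the best configuration, with four independent runs agreeing (run68/92/93/115, ~11.8 s).
  2. Remat ranking: save_out_proj > save_dot_ctx_ex_mlp > save_qkv_proj.
  3. Step times across the valid EP=1 configurations differ by <18%, so the MFU ceiling is set by kernel efficiency.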

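Canonical Baseline Step Time Breakdown (run115, XPlane + named_scope analysis)

Data source: ci-prof-run115 XPlane (TPU:0). The configuration matches run68 (DP=4, FSDP=32, save_out_proj, pdb=10×ga=4), but named_scope annotations are also enabled so that ops can be attributed to model components. Step time is 11.68 s, consistent with run68 (11.78 s). Because named_scope lets GLA be attributed on its own, the table below pulls GLA out of MatMul / Elementwise / Other into a separate row. Analysis script: scripts/analyze_run115_final_v2.py

Category | Time (ms) | Share | Op count | Avg (ms) | Fwd / Bwd (ms) | Notes
GLA (Pallas + Dense + Elemwise) | 2,868 | 24.6% | 136,836 | 0.021 | 729 / 2,139 | Largest single hotspot; Pallas kernel 570 ms · qkv+out matmul 1,062 ms · elemwise 546 ms · data format 417 ms · other 273 ms; bwd is 100% remat recompute
GMM (MoE forward + backward dlhs) | 1,909 | 16.3% | 708 | 2.696 | 642 / 1,268 | 20 MoE layers × 3 ops × ga=4; bwd 97% remat
Async Fusion Stall | 1,356 | 11.6% | 1,504 | 0.901 | 477 / 879 | Waiting on communication-compute overlap (managed automatically by XLA)
MatMul (Dense, non-GLA) | 1,143 | 9.8% | 1,624 | 0.704 | 322 / 820 | MoE expert 41% · output proj 29% · logits 29%; bwd 42% remat
Elementwise (non-GLA) | 924 | 7.9% | 27,368 | 0.034 | 315 / 597 | Activations, norm, residual (GLA portion extracted)
TGMM (MoE backward drhs) | 668 | 5.7% | 240 | 2.783 | - / 668 | Backward weight gradients (deferred by XLA scheduling); 95% in remat scope
Custom Fusion (MoE dispatch) | 668 | 5.7% | 5,004 | 0.133 | 327 / 186 | Token permute/unpermute
Attention (MLA SplashAttn) | 628 | 5.4% | 76 | 8.267 | 149 / 479 | Only 4 MLA layers; fwd 149 ms · dKV 188 ms · dQ 172 ms; 85% remat
All-Gather communication | 404 | 3.5% | 1,108 | 0.365 | 29 / 375 | FSDP weight gathers (the 38 ms GLA all-gather is extracted)
All-Reduce communication | 393 | 3.4% | 392 | 1.001 | - / 393 | Gradient synchronization
Routing/Sort | 363 | 3.1% | 948 | 0.383 | 206 / 157 | top-k routing
Data Formatting (non-GLA) | 158 | 1.3% | 5,900 | 0.027 | 26 / 48 | XLA layout conversions (the 417 ms GLA portion is extracted)
Other (AllToAll, ReduceScatter, etc.) | 198 | 1.7% | 2,301 | 0.086 | | AllToAll 87 ms · ReduceScatter 55 ms · misc 56 ms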

ConcatBitcast (DMA offload stall): XLA fully overlaps its 2,921 ms of DMA wait time with the computation above, so it adds no wall-clock time. With ConcatBitcast excluded, the remaining ops sum to 11,680 ms ≈ step time.

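Four-Way Attribution

Attribution | Time (ms) | Share | Meaning
GLA compute | 2,868 | 24.6% | GLA Pallas kernel + projections + elemwise + format conversion
MoE-specific compute | 3,608 | 30.9% | GMM 1,909 + TGMM 668 + dispatch 668 + routing 363
Efficient compute (non-GLA) | 1,771 | 15.2% | MatMul 1,143 + Attention 628
Non-compute overhead | 3,434 | 29.4% | comm 995 + stall 1,356 + elemwise 924 + format 158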
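The four-way attribution regroups the category table above it. The Python sketch below recomputes it under two choices. First, the comm term is assumed to be All-Gather + All-Reduce + Other (404 + 393 + 198 = 995 ms, which matches the stated value). Second, the 11,680 ms op-time total is used as the denominator. Non-compute overhead comes out at 3,433 ms rather than 3,434 ms, a rounding difference in the per-category inputs.

```python
# Category times (ms) from the run115 breakdown table.
cat_ms = {
    "GLA": 2868, "GMM": 1909, "AsyncStall": 1356, "MatMul": 1143,
    "Elementwise": 924, "TGMM": 668, "Dispatch": 668, "Attention": 628,
    "AllGather": 404, "AllReduce": 393, "Routing": 363,
    "DataFormat": 158, "Other": 198,
}
# Which categories feed each attribution bucket.
buckets = {
    "GLA compute": ["GLA"],
    "MoE-specific compute": ["GMM", "TGMM", "Dispatch", "Routing"],
    "Efficient compute (non-GLA)": ["MatMul", "Attention"],
    "Non-compute overhead": ["AllGather", "AllReduce", "Other",
                             "AsyncStall", "Elementwise", "DataFormat"],
}
TOTAL_MS = 11_680  # op-time total after excluding ConcatBitcast

attribution = {name: sum(cat_ms[c] for c in cats) for name, cats in buckets.items()}
for name, ms in attribution.items():
    print(f"{name:30s} {ms:6,d} ms  {ms / TOTAL_MS:6.1%}")
print("sum:", sum(attribution.values()), "ms")
```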

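Remat overhead: under the save_out_proj policy, 89% of backward-pass op time is spent recomputing the forward pass (remat). By kernel, remat accounts for 100% of the GLA backward, 97% of the GMM backward, and 85% of the Attention backward. Total remat time is about 7,285 ms, or 62.4% of step time.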

named_scope component attribution: MoE layers 28.2% · GLA layers 24.6% · MTP block 4.1% · MLA attention 2.2% · Output projection 2.5% · remainder 38.5% (async stalls, communication, elementwise, and other ops with no clear component owner)

Kernel Profiling Results

GLA (forward + backward, 20.4 ms per step, ci-op-prof-run17/18, pdb=10)

Pallas chunk_simple_gla kernel, covering chunk_fwd_h + chunk_bwd_dh + chunk_simple_gla_bwd_o_pl.

Op | Time (ms) | Share | Notes
transpose_jvp (backward) | 9.50 | 46.6% | Pallas backward kernel, including fori_loop control flow
while.646 (forward scan) | 1.93 | 9.5% | Pallas chunk scan loop
convolution fusion (matmul) | 3.82 | 18.7% | AI=76.8, far below ridge=313
data formatting (copy) | 2.59 | 12.7% | HBM→VMEM layout conversion

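Hardware unit utilization (GLA backward, 9.5 ms window)

Unit | Mean % | Max % | Notes
Scalar ALU | 34.0% | 54.5% | Dominant: fori_loop/cond control flow inside the Pallas kernel
MXU | 11.6% | 25.2% | Low matrix-unit utilization
Vector ALU | 7.6% | 16.5% |
Vector Load | 8.8% | 12.1% |
Vector Store | 3.3% | 7.1% |

Bottleneck analysis: Scalar ALU utilization is about 3× that of the MXU. The fori_loop/lax.cond control flow in the Pallas backward kernel generates a large volume of scalar work, and the MXU sits idle most of the time. The backward pass takes 4.9× as long as the forward pass.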

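Megablox GMM/TGMM (forward + backward, 11.8 ms per step, ci-op-prof-run17, pdb=10)

Op | Time (ms) | Share | Notes
transpose_jvp_tgmm (×2) | 5.54 | 46.9% | TGMM backward
transpose_jvp_gmm (×2) | 5.14 | 43.6% | GMM backward
broadcast (initialization) | 1.00 | 8.5% | Zeroing the output buffer
Forward GMM | not visible | | Inlined by XLA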
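GMM multiplies each contiguous group of rows by that group's own weight matrix, which covers the forward pass and the dlhs gradient. TGMM produces one weight gradient per group (drhs). The pure-Python reference below only pins down these shapes and semantics, and it assumes tokens are already sorted by expert. It is not the Pallas implementation, and the small matrices are illustrative.

```python
# Reference semantics only (not the Pallas kernels). Rows of `lhs` are assumed
# sorted by expert; the first group_sizes[0] rows belong to expert 0, and so on.

def matmul(a, b):
    """Plain [p, q] x [q, r] matmul on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gmm(lhs, rhs, group_sizes):
    """lhs: [m, k], rhs: [G, k, n] -> [m, n]; rows of group g are multiplied by rhs[g]."""
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        out += matmul(lhs[start:start + size], rhs[g])
        start += size
    return out

def tgmm(lhs, grad, group_sizes):
    """lhs: [m, k], grad: [m, n] -> [G, k, n]; result[g] = lhs_g^T @ grad_g."""
    k, n = len(lhs[0]), len(grad[0])
    drhs, start = [], 0
    for size in group_sizes:
        if size == 0:
            # An expert that received no tokens gets a zero gradient.
            drhs.append([[0] * n for _ in range(k)])
        else:
            lhs_g_t = [list(col) for col in zip(*lhs[start:start + size])]  # [k, size]
            drhs.append(matmul(lhs_g_t, grad[start:start + size]))
        start += size
    return drhs

lhs = [[1, 2], [3, 4], [5, 6]]               # m=3 tokens, k=2
rhs = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]   # G=2 experts, each [k=2, n=2]
sizes = [1, 2]                               # token 0 -> expert 0, tokens 1-2 -> expert 1
print(gmm(lhs, rhs, sizes))                  # [[1, 2], [6, 8], [10, 12]]
print(tgmm(lhs, [[1, 1]] * 3, sizes))        # [[[1, 1], [2, 2]], [[8, 8], [10, 10]]]
```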

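Hardware unit utilization (GMM/TGMM backward)

Unit | TGMM backward | GMM backward | Notes
MXU | mean 33.4%, max 84.1% | mean 39.9%, max 85.5% |
Scalar ALU | mean 39.5%, max 87.8% | mean 36.5%, max 85.2% | Close to MXU
Vector Store | mean 25.4%, max 96.4% | mean 21.4%, max 97.0% | Write bandwidth saturated
Vector Load | mean 12.5%, max 83.9% | mean 10.2%, max 32.6% |

Bottleneck analysis: MXU (33-40%) and Scalar ALU (36-40%) utilization are similar, so the kernel is doing matrix work but at limited efficiency. Vector Store peaking at 97% shows that VMEM→HBM write-back is the key bottleneck. Per-expert AI=310 is close to the ridge point of 313, so the problem lies mainly in the kernel's tiling/pipelining implementation.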

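GMM Tiling Experiments

Tiling | pdb=2 | pdb=10 | vs. default
512,1024,1024 (default) | 4.2 ms | 13 ms |
128,256,128 | 41 ms | | +876%
512,512,512 | 5 ms | 15 ms | +15-19%

The default tiling is the best. Small tiles produce 64× more grid iterations, and that overhead far outweighs any compute savings.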

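Roofline Positioning

Op | AI (FLOPs/byte) | vs. ridge (313) | Roofline ceiling | Actual efficiency | Bound
Dense qkv_proj | 1,296 | 4.1× | ~100% | ~100% | Compute
GMM wi_0 (pdb=10) | 178 | 0.57× | 57% | ~11% | Memory
TGMM wi_0 (pdb=10) | ~150 | 0.48× | 48% | ~7.5% | Memory
GLA intra_qk | 25.6 | 0.08× | 8% | ~0.4% | Memory

With 256 experts and top-8 routing, MoE splits one large matrix multiply into 256 small ones, and each expert receives only ~50 tokens (pdb=10). AI therefore falls below the ridge point, and these kernels are HBM-bandwidth bound.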
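The roofline ceiling column is min(1, AI / ridge), with ridge = 313 FLOPs/byte (peak FLOPS divided by HBM bandwidth) as used throughout this page. The Python sketch below also shows why few tokens per expert lowers AI. For a BF16 GEMM of [m, k] × [k, n] that reads both inputs and writes the output once, AI = 2mkn / (2(mk + kn + mn)). `K_HYP` and `N_HYP` are hypothetical placeholders, not the model's actual expert dimensions. The byte count also ignores re-reads caused by tiling.

```python
RIDGE = 313.0  # FLOPs/byte, the ridge point used throughout this page

def roofline_ceiling(ai: float) -> float:
    """Fraction of peak FLOPS attainable at arithmetic intensity `ai` (FLOPs/byte)."""
    return min(1.0, ai / RIDGE)

def gemm_ai(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """AI of an [m, k] x [k, n] matmul that reads each input and writes the output once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

# Ceiling column of the roofline table.
for name, ai in [("Dense qkv_proj", 1296), ("GMM wi_0", 178),
                 ("TGMM wi_0", 150), ("GLA intra_qk", 25.6)]:
    print(f"{name:15s} AI={ai:7.1f}  ceiling={roofline_ceiling(ai):.0%}")

# Hypothetical expert shape: while m << k and n, AI is roughly m.
K_HYP, N_HYP = 2048, 1024
for m in (50, 256, 1024, 4096):
    print(f"m={m:5d} tokens/expert -> AI={gemm_ai(m, K_HYP, N_HYP):6.1f}")
```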

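Optimization Items and Estimated Headroom

Framework-level optimizations

Item | Category | Est. savings | Rationale | Priority | Status
Data Loader async prefetch | DataLoader | | Make host→device data transfer asynchronous | | ✅ Implemented
Eliminate data formatting | XLA Layout | 10-20 ms/step | Some of the 574 ms of copies could be avoided with layout hints | P2 | To be analyzed
Safe subset of XLA SparseCore flags | XLA compiler | 10-45 ms/step | Isolate the subset of the 17 flags that does not cause OOM | P1 | Enabling all flags at once causes OOM; flags must be tested one by one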
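One way to organize the one-by-one testing of the 17 SparseCore flags is a greedy search. Flags are added one at a time, and each is kept only if the run still fits in memory. In this Python sketch, the flag names (`sc_flag_i`) and the `toy_fits` check are hypothetical stand-ins for a real compile-and-run. When flags interact, the result depends on trial order. The search therefore returns a safe subset, not necessarily the largest one.

```python
from typing import Callable, Sequence

def greedy_safe_subset(flags: Sequence[str],
                       fits_in_memory: Callable[[list[str]], bool]) -> tuple[list[str], list[str]]:
    """Try flags one at a time; keep each only if the accumulated set still fits.

    Costs exactly len(flags) trial runs. Returns (kept, rejected).
    """
    kept: list[str] = []
    rejected: list[str] = []
    for flag in flags:
        candidate = kept + [flag]
        if fits_in_memory(candidate):
            kept = candidate
        else:
            rejected.append(flag)
    return kept, rejected

# Toy stand-in for "compile and run a few steps without OOM" (hypothetical numbers).
FLAGS = [f"sc_flag_{i}" for i in range(17)]

def toy_fits(active: list[str]) -> bool:
    hbm_gb = 90.0 + 0.5 * len(active)   # each flag adds a little memory
    if "sc_flag_5" in active:           # one flag is too expensive on its own
        hbm_gb += 20.0
    return hbm_gb <= 96.0

kept, rejected = greedy_safe_subset(FLAGS, toy_fits)
print(len(kept), "kept; rejected:", rejected)
```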

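Parallelism strategy optimizations

Item | Category | Est. savings | Rationale | Priority | Status
Pipeline Parallelism (PP) | Parallelism | To be evaluated | Heterogeneous layers (GLA/MLA) do not naturally fit PP stage grouping | P2 | Requires ALPipelineStage support
scan_layers | Parallelism | Reduced compile time | Prerequisite for PP; heterogeneous layers require ALPipelineStage | | ✅ Implemented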

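Operator-level optimizations

Item | Category | Est. savings | Rationale | Priority | Status
GLA backward kernel optimization | GLA | 6-8 ms/layer | Backward 9.5 ms vs. forward 1.9 ms (4.9×), dominated by Scalar ALU (34%) | P0 | Blocked by JAX dynamic_slice
GLA chunk size 64→128 | GLA | 5-10 ms/step | AI rises from 25 to 43 and the number of scan steps is halved | | ❌ Abandoned: affects accuracy
MoE Sort+Gather fusion | MoE | 20-40 ms/step | routing/sort = 363 ms, with repeated HBM accesses | P1 | To be implemented
GMM/TGMM Vector Store optimization | MoE | To be evaluated | Vector Store peaks at 97% (saturated); the write-back pipeline needs improvement | P2 | Requires kernel-level work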
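The pure-Python sketch below is a hypothetical reference for the data movement that the Sort+Gather fusion item targets. From top-k expert ids, it computes three outputs. The first is the expert-sorted token order (the gather indices). The second is the per-expert group sizes, which is what a grouped matmul consumes. The third is the inverse permutation used to unpermute results. The table attributes 363 ms to routing/sort with repeated HBM accesses. The proposed fusion would compute the sorted order and gather the token rows together instead of in separate passes.

```python
# Hypothetical reference for routing permute/unpermute (not the production kernel).
def route_and_permute(topk_ids: list[list[int]], num_experts: int):
    """Return (gather_idx, group_sizes, inverse) for top-k routing.

    gather_idx[p]  : source token for position p of the expert-sorted buffer (permute)
    group_sizes[e] : number of (token, slot) assignments routed to expert e
    inverse[i]     : sorted position of assignment i, used to scatter results back
    """
    # Flatten per-token top-k choices into (expert, token) assignments.
    pairs = [(e, t) for t, experts in enumerate(topk_ids) for e in experts]
    # Stable sort of assignment indices by expert id (the "sort" step).
    order = sorted(range(len(pairs)), key=lambda i: pairs[i][0])
    # Token indices in expert-sorted order (the "gather" step reads these rows).
    gather_idx = [pairs[i][1] for i in order]
    group_sizes = [0] * num_experts
    for e, _ in pairs:
        group_sizes[e] += 1
    inverse = [0] * len(order)
    for pos, i in enumerate(order):
        inverse[i] = pos
    return gather_idx, group_sizes, inverse

# 3 tokens, top-2 routing over 4 experts.
gather_idx, group_sizes, inverse = route_and_permute([[2, 0], [1, 2], [0, 1]], num_experts=4)
print(gather_idx, group_sizes, inverse)
```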

Profiling data: gs://ant-pretrain/pretrain/profiling/ and gs://ant-pretrain/pretrain/operator-profiling/