
Theoretical Analysis Results

ALModel Parallelism Strategy Analysis

Model

| Property | Value |
|---|---|
| Parameters | 17.43B (expert: 16.11B, non-expert: 1.33B) |
| Layers | 21 total (16 GLA + 5 MLA, 20 MoE + 1 dense MLP) |
| Experts | 256 x FFN 512, top_k=8 |
| Attention | 16h x 128d, EMB_DIM=2048 |

Training

| Property | Value |
|---|---|
| Global batch size | 5120 |
| Sequence length | 4096 |
| Forward FLOPs/sample | 13.33 TFLOPs |
| Backward FLOPs/sample | 26.63 TFLOPs |
| Training FLOPs/sample | 39.96 TFLOPs |
| Training FLOPs/step | 204.57 PFLOPs |
| Remat overhead (full) | +33.4% |
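
The per-step figure in the table is the per-sample training cost scaled by the global batch size. The sketch below reproduces that arithmetic from the rounded table values; the function and constant names are illustrative and do not come from the analysis tool.

```python
# Values copied from the Training table (rounded as displayed there).
GLOBAL_BATCH = 5120
FWD_TFLOPS_PER_SAMPLE = 13.33
BWD_TFLOPS_PER_SAMPLE = 26.63


def training_pflops_per_step(fwd_tflops, bwd_tflops, global_batch):
    """Useful (no-remat) training FLOPs per step, in PFLOPs."""
    per_sample_tflops = fwd_tflops + bwd_tflops
    return per_sample_tflops * global_batch / 1000.0  # TFLOPs -> PFLOPs


step_pflops = training_pflops_per_step(
    FWD_TFLOPS_PER_SAMPLE, BWD_TFLOPS_PER_SAMPLE, GLOBAL_BATCH
)
print(f"{step_pflops:.1f} PFLOPs/step")
```

The result agrees with the reported 204.57 PFLOPs up to rounding of the per-sample inputs.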

Hardware

| Property | Value |
|---|---|
| Devices | 128 x 1154 TFLOPs/device = 147648 TFLOPs total |
| HBM | 96 GiB/device |
| ICI bandwidth | 1200 GB/s bidirectional (nominal) |
| Measured BW (GB/s) | AR={2: 23.1, 4: 46.3, 8: 80.1}, AG={2: 34.3, 4: 89.9, 8: 186.3}, RS={2: 46.0, 4: 92.5, 8: 185.3}, A2A= |
| PP ppermute | 600 GB/s (overlapped) |
| FSDP overlap | 85% |
| XLA reserve | 30% |
| Total memory (no parallelism) | W=64.9GB + O=129.9GB + G=64.9GB = 259.8GB |
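
The unsharded memory total can be reproduced from the parameter count. The sketch below assumes 4 bytes per parameter for weights and gradients, two Adam moments of the same size, and binary units (1 GB = 2^30 bytes). These assumptions are inferred from the reported figures, not stated by the report.

```python
GIB = 1024 ** 3  # assumption: the report's "GB" are binary gigabytes


def unsharded_memory(num_params, bytes_per_param=4):
    """Return (weights, optimizer, gradients, total) in GiB with no sharding."""
    weights = num_params * bytes_per_param / GIB
    optimizer = 2 * weights  # assumption: Adam mu + nu, same dtype as weights
    gradients = weights
    return weights, optimizer, gradients, weights + optimizer + gradients


w, o, g, total = unsharded_memory(17.43e9)
print(f"W={w:.1f} O={o:.1f} G={g:.1f} total={total:.1f}")
```

Small differences from the table in the last digit come from the parameter count being rounded to 17.43B.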

Configs with MFU >= 20% (sorted by step time)

python
  REMAT_LABELS = {
      "save_all":                         "save_all",
      "minimal_with_context":             "min+ctx",
      "minimal":                          "minimal",
      "save_dot_with_context_except_mlp": "dot-mlp+ctx",
      "save_dot_except_mlpwi":            "dot-mlpwi",
      "save_dot_except_mlp":              "dot-mlp",
      "save_qkv_proj":                    "qkv_proj",
      "save_out_proj":                    "out_proj",
      "full":                             "full",
  }
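| # | TP | DP | PP | EP | FSDP | CP | Remat | PDB | MB | GA | W(GB) | O(GB) | G(GB) | FBuf | Act(GB) | Rsv | Tot(GB) | Trn(PF) | Rmt% | Comp | TP(s) | EP(s) | FSDP+ | DP(s) | CP(s) | PP+(s) | Bub | Opt | Step(s) | MFU% | Bottleneck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 8 | 1 | 1 | 16 | 1 | full | 10 | 10 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.04 | 0.00 | 0.00 | 0.00 | 0.01 | 5.89 | 23.5 | FSDP |
| 2 | 1 | 4 | 1 | 1 | 16 | 2 | full | 10 | 20 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.03 | 0.13 | 0.00 | 0.00 | 0.01 | 6.01 | 23.0 | FSDP |
| 3 | 1 | 4 | 1 | 1 | 32 | 1 | out_proj | 10 | 10 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.02 | 0.00 | 0.00 | 0.00 | 0.00 | 6.04 | 22.9 | FSDP |
| 4 | 1 | 2 | 1 | 1 | 64 | 1 | qkv_proj | 10 | 10 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 61.8 | 18.5 | 92.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 4.32 | 0.01 | 0.00 | 0.00 | 0.00 | 0.00 | 6.11 | 22.7 | FSDP |
| 5 | 1 | 1 | 1 | 1 | 128 | 1 | dot-mlp | 10 | 10 | 4 | 0.5 | 1.0 | 0.5 | 8.3 | 65.1 | 18.9 | 94.3 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 4.38 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 6.14 | 22.6 | FSDP |
| 6 | 1 | 2 | 1 | 1 | 32 | 2 | out_proj | 10 | 20 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.01 | 0.13 | 0.00 | 0.00 | 0.00 | 6.16 | 22.5 | FSDP |
| 7 | 1 | 2 | 1 | 1 | 16 | 4 | full | 10 | 40 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.02 | 0.34 | 0.00 | 0.00 | 0.01 | 6.21 | 22.3 | FSDP |
| 8 | 1 | 1 | 1 | 1 | 64 | 2 | qkv_proj | 10 | 20 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 61.8 | 18.5 | 92.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 4.32 | 0.00 | 0.13 | 0.00 | 0.00 | 0.00 | 6.24 | 22.2 | FSDP |
| 9 | 1 | 1 | 1 | 1 | 32 | 4 | out_proj | 10 | 40 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.00 | 0.34 | 0.00 | 0.00 | 0.00 | 6.36 | 21.8 | FSDP |
| 10 | 1 | 1 | 1 | 1 | 16 | 8 | full | 10 | 80 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.00 | 0.76 | 0.00 | 0.00 | 0.01 | 6.60 | 21.0 | FSDP |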
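
Each row listed above is a factorization of the 128-device mesh: the six parallelism degrees multiply to the device count. The short check below illustrates that constraint on the (TP, DP, PP, EP, FSDP, CP) tuples read from the table. The helper name is hypothetical, and whether the analysis tool enforces the constraint in this way is an assumption.

```python
from math import prod

NUM_DEVICES = 128

# (TP, DP, PP, EP, FSDP, CP) for the ten rows listed in the table.
ROWS = [
    (1, 8, 1, 1, 16, 1),
    (1, 4, 1, 1, 16, 2),
    (1, 4, 1, 1, 32, 1),
    (1, 2, 1, 1, 64, 1),
    (1, 1, 1, 1, 128, 1),
    (1, 2, 1, 1, 32, 2),
    (1, 2, 1, 1, 16, 4),
    (1, 1, 1, 1, 64, 2),
    (1, 1, 1, 1, 32, 4),
    (1, 1, 1, 1, 16, 8),
]


def uses_all_devices(degrees, num_devices=NUM_DEVICES):
    """True when the parallelism degrees exactly tile the device mesh."""
    return prod(degrees) == num_devices


all_rows_valid = all(uses_all_devices(row) for row in ROWS)
print(all_rows_valid)
```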

Column Legend

| Column | Description |
|---|---|
| TP / DP / PP / EP / FSDP / CP | Tensor / Data / Pipeline / Expert / Fully-Sharded-Data / Context parallelism degree |
| Remat | Rematerialization (activation checkpointing) policy |
| PDB | per_device_batch_size, the MaxText config value for batch size per device |
| MB | Actual physical per-device batch = PDB × TP × CP (the batch is not sharded along the tensor or context axes) |
| GA | Gradient accumulation steps |
| W(GB) | Model weights per device (GB) |
| O(GB) | Optimizer states per device (GB), including Adam mu + nu |
| G(GB) | Gradients per device (GB) |
| FBuf | FSDP all-gather prefetch buffer (GB), peak = 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Act(GB) | Activation memory per device (GB), depends on micro batch and remat policy |
| Rsv | XLA overhead reserve (GB), 30% of modeled memory, covering comm buffers and HLO temps |
| Tot(GB) | Total HBM usage per device (GB) = W + O + G + FBuf + Act + Rsv |
| Trn(PF) | Training FLOPs per step (PFLOPs), useful fwd+bwd without remat overhead |
| Rmt% | Remat overhead percentage, the extra compute spent recomputing activations |
| Comp | Compute time (s) = actual FLOPs / (num_devices x peak TFLOPs/device) |
| TP(s) | TP all-reduce communication time (s), non-overlappable |
| EP(s) | EP all-to-all communication time (s), non-overlappable |
| FSDP+ | FSDP exposed communication time (s), partially overlapped with compute (overlap efficiency = 85%) |
| DP(s) | DP gradient all-reduce time (s), non-overlappable |
| CP(s) | CP KV all-gather communication time (s), non-overlappable |
| PP+(s) | PP ppermute communication time (s), fully overlapped by the XLA scheduler |
| Bub(s) | PP bubble idle time (s), derived from the bubble fraction (PP-1)/(num_repeats x GA + PP-1) |
| Opt | Optimizer step time (s), memory-bound AdamW (28 bytes/param, HBM_BW=3690GB/s) |
| Step(s) | Total step time (s) = Comp + all comm + Bub + Opt |
| MFU% | Model FLOPs Utilization (%) = useful TFLOPs / (step_time x num_devices x peak TFLOPs) |
| Bottleneck | The dominant time component limiting throughput |
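
The Comp, Bub, and MFU% definitions in the legend can be written down directly. The sketch below is a minimal rendering of those three formulas, using the aggregate peak from the Hardware table (147648 TFLOPs). The function names are illustrative, and the sample inputs (269.1 PFLOPs actual, 204.57 PFLOPs useful, 4.402 s step time) are the values the report gives for its best configuration.

```python
TOTAL_PEAK_TFLOPS = 147648.0  # 128 devices, aggregate peak from the Hardware table


def compute_time_s(actual_pflops, total_peak_tflops=TOTAL_PEAK_TFLOPS):
    """Comp: time to execute all FLOPs (including remat) at peak throughput."""
    return actual_pflops * 1000.0 / total_peak_tflops


def bubble_fraction(pp, ga, num_repeats=1):
    """Pipeline bubble fraction from the legend; zero when PP == 1."""
    return (pp - 1) / (num_repeats * ga + pp - 1)


def mfu_percent(useful_pflops, step_time_s, total_peak_tflops=TOTAL_PEAK_TFLOPS):
    """MFU%: useful FLOPs divided by what the cluster could do in one step."""
    return 100.0 * useful_pflops * 1000.0 / (step_time_s * total_peak_tflops)


comp = compute_time_s(269.1)
mfu = mfu_percent(204.57, 4.402)
print(f"Comp={comp:.3f}s  MFU={mfu:.1f}%  bubble(PP=1)={bubble_fraction(1, 4)}")
```

Because MFU counts only useful FLOPs, a heavier remat policy raises Comp without raising the numerator.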

#1: TP=1 DP=4 PP=1 EP=1 FSDP=32 CP=1 remat=save_out_proj

Memory per device

| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 2.03 | expert: 1.88 + non-expert: 0.15 |
| Optimizer | 4.06 | |
| Gradients | 2.03 | |
| FSDP buffer | 8.14 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=10 x inflight=1, remat=save_out_proj |
| XLA reserve | 21.24 | 30% of modeled |
| Total | 92.02 | / 96 GiB HBM (96% used) |
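
The total in this table follows the legend's rule Tot = W + O + G + FBuf + Act + Rsv, with the reserve taken as 30% of the modeled components. A minimal sketch of that sum, using the values from the table above (function name illustrative):

```python
HBM_GIB = 96.0


def per_device_memory(weights, optimizer, grads, fsdp_buf, activations,
                      reserve_frac=0.30):
    """Return (modeled, reserve, total) per-device memory in GB."""
    modeled = weights + optimizer + grads + fsdp_buf + activations
    reserve = reserve_frac * modeled  # XLA reserve: fraction of modeled memory
    return modeled, reserve, modeled + reserve


modeled, reserve, total = per_device_memory(2.03, 4.06, 2.03, 8.14, 54.53)
fits = total <= HBM_GIB
print(f"reserve={reserve:.2f} total={total:.2f} fits={fits}")
```

Activations dominate the budget here, which is why the remat policy decides whether a configuration fits in HBM.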

Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40
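
The batch line can be reproduced from the global batch size of 5120. The sketch assumes that gradient accumulation is whatever remains after dividing the global batch by batch_par x PDB, and that the effective batch per device is micro_batch x GA. Both relations are inferred from the reported numbers rather than stated by the report.

```python
def batch_plan(global_batch, batch_par, pdb, tp, cp):
    """Return (micro_batch, grad_accum_steps, effective_batch_per_device)."""
    micro_batch = pdb * tp * cp  # legend: MB = PDB x TP x CP
    ga, remainder = divmod(global_batch, batch_par * pdb)
    if remainder:
        raise ValueError("global batch is not divisible by batch_par * PDB")
    return micro_batch, ga, micro_batch * ga


plan = batch_plan(global_batch=5120, batch_par=128, pdb=10, tp=1, cp=1)
print(plan)
```

With CP=2 and the other inputs unchanged, the same function gives micro_batch=20 and an effective batch of 80 per device, as in the CP=2 configurations later in this report.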

FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%

Communication

| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 754.9 | 4.06 total, 2.51 exposed | overlap=85% |
| DP all-reduce | 3.0 | 0.07 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 14.2 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
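
The FSDP row reports both a total and an exposed time. One reading that reproduces the numbers shown is that 85% of the compute time is available to hide FSDP traffic, so exposed = max(0, total - 0.85 x compute). This is an assumption about the model, not something the report states. The optimizer time follows the legend: bytes touched divided by HBM bandwidth. The sketch uses compute = 1.823 s, the value reported for this configuration.

```python
def exposed_fsdp_time(total_comm_s, compute_s, overlap=0.85):
    """Assumed model: a fraction `overlap` of compute time hides FSDP traffic."""
    return max(0.0, total_comm_s - overlap * compute_s)


def optimizer_time_s(traffic_gb, hbm_gb_per_s=3690.0):
    """Memory-bound optimizer step: bytes moved divided by HBM bandwidth."""
    return traffic_gb / hbm_gb_per_s


exposed = exposed_fsdp_time(total_comm_s=4.06, compute_s=1.823)
opt = optimizer_time_s(14.2)
print(f"exposed={exposed:.2f}s optimizer={opt:.3f}s")
```

Under this reading, FSDP traffic stays exposed whenever its total time exceeds 85% of compute time, which is consistent with every configuration in this report being FSDP-bound.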

Performance

  • Step time: 4.402s (compute=1.823 + comm=2.576 + optimizer=0.004)

  • MFU: 31.5%

  • Bottleneck: FSDP

#2: TP=1 DP=1 PP=1 EP=1 FSDP=128 CP=1 remat=save_qkv_proj

Memory per device

| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 0.51 | expert: 0.47 + non-expert: 0.04 |
| Optimizer | 1.01 | |
| Gradients | 0.51 | |
| FSDP buffer | 8.33 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 61.80 | micro_batch=10 x inflight=1, remat=save_qkv_proj |
| XLA reserve | 21.65 | 30% of modeled |
| Total | 93.81 | / 96 GiB HBM (98% used) |

Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40

FLOPs: useful=204.6 PFLOPs, actual=263.4 PFLOPs, remat=+29%

Communication

| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 773.2 | 4.16 total, 2.64 exposed | overlap=85% |
| DP all-reduce | 0.0 | 0.00 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 3.6 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |

Performance

  • Step time: 4.426s (compute=1.784 + comm=2.641 + optimizer=0.001)

  • MFU: 31.3%

  • Bottleneck: FSDP

#3: TP=1 DP=2 PP=1 EP=1 FSDP=64 CP=1 remat=save_out_proj

Memory per device

| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 1.01 | expert: 0.94 + non-expert: 0.08 |
| Optimizer | 2.03 | |
| Gradients | 1.01 | |
| FSDP buffer | 8.27 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=10 x inflight=1, remat=save_out_proj |
| XLA reserve | 20.06 | 30% of modeled |
| Total | 86.91 | / 96 GiB HBM (91% used) |

Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40

FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%

Communication

| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 767.1 | 4.12 total, 2.58 exposed | overlap=85% |
| DP all-reduce | 1.0 | 0.04 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 7.1 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |

Performance

  • Step time: 4.444s (compute=1.823 + comm=2.620 + optimizer=0.002)

  • MFU: 31.2%

  • Bottleneck: FSDP

#4: TP=1 DP=1 PP=1 EP=1 FSDP=64 CP=2 remat=save_out_proj

Memory per device

| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 1.01 | expert: 0.94 + non-expert: 0.08 |
| Optimizer | 2.03 | |
| Gradients | 1.01 | |
| FSDP buffer | 8.27 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=20 x inflight=1, remat=save_out_proj |
| XLA reserve | 20.06 | 30% of modeled |
| Total | 86.91 | / 96 GiB HBM (91% used) |

Batch: per_device_batch=10, micro_batch=20 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=80

FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%

Communication

| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 767.1 | 4.12 total, 2.58 exposed | overlap=85% |
| DP all-reduce | 0.0 | 0.00 | non-overlappable |
| CP all-gather (KV) | 12.5 | 0.32 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 7.1 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |

Performance

  • Step time: 4.718s (compute=1.823 + comm=2.894 + optimizer=0.002)

  • MFU: 29.4%

  • Bottleneck: FSDP

#5: TP=1 DP=2 PP=1 EP=1 FSDP=32 CP=2 remat=save_out_proj

Memory per device

| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 2.03 | expert: 1.88 + non-expert: 0.15 |
| Optimizer | 4.06 | |
| Gradients | 2.03 | |
| FSDP buffer | 8.14 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=20 x inflight=1, remat=save_out_proj |
| XLA reserve | 21.24 | 30% of modeled |
| Total | 92.02 | / 96 GiB HBM (96% used) |

Batch: per_device_batch=10, micro_batch=20 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=80

FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%

Communication

| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 754.9 | 4.06 total, 2.51 exposed | overlap=85% |
| DP all-reduce | 2.0 | 0.09 | non-overlappable |
| CP all-gather (KV) | 12.5 | 0.32 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 14.2 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |

Performance

  • Step time: 4.743s (compute=1.823 + comm=2.916 + optimizer=0.004)

  • MFU: 29.2%

  • Bottleneck: FSDP

Worst 10 configurations (slowest)


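| # | TP | DP | PP | EP | FSDP | CP | Remat | PDB | MB | GA | W(GB) | O(GB) | G(GB) | FBuf | Act(GB) | Rsv | Tot(GB) | Trn(PF) | Rmt% | Comp | TP(s) | EP(s) | FSDP+ | DP(s) | CP(s) | PP+(s) | Bub | Opt | Step(s) | MFU% | Bottleneck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 2 | 1 | 2 | 4 | 8 | save_all | 1 | 8 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 19.3 | 17.5 | 75.8 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.38 | 0.47 | 0.00 | 0.00 | 0.02 | 40.41 | 3.4 | FSDP |
| 2 | 1 | 4 | 1 | 2 | 4 | 4 | save_all | 1 | 4 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 20.2 | 17.8 | 77.0 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.28 | 0.41 | 0.00 | 0.00 | 0.02 | 40.25 | 3.4 | FSDP |
| 3 | 1 | 8 | 1 | 2 | 4 | 2 | save_all | 1 | 2 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 22.1 | 18.3 | 79.4 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.19 | 0.32 | 0.00 | 0.00 | 0.02 | 40.07 | 3.5 | FSDP |
| 4 | 1 | 16 | 1 | 2 | 4 | 1 | save_all | 1 | 1 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 25.8 | 19.5 | 84.3 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.20 | 0.00 | 0.00 | 0.00 | 0.02 | 39.76 | 3.5 | FSDP |
| 5 | 1 | 8 | 1 | 1 | 16 | 1 | save_all | 1 | 1 | 40 | 4.1 | 8.1 | 4.1 | 7.9 | 25.8 | 15.0 | 64.9 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 38.11 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 39.59 | 3.5 | FSDP |
| 6 | 1 | 2 | 1 | 1 | 8 | 8 | save_all | 1 | 8 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 19.3 | 17.7 | 76.8 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 37.71 | 3.7 | FSDP |
| 7 | 1 | 4 | 1 | 1 | 8 | 4 | save_all | 1 | 4 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 20.2 | 18.0 | 78.1 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 37.56 | 3.7 | FSDP |
| 8 | 1 | 2 | 1 | 4 | 2 | 8 | save_all | 1 | 8 | 40 | 10.0 | 19.9 | 10.0 | 1.9 | 19.3 | 18.3 | 79.4 | 204.6 | 0 | 1.39 | 0.00 | 4.50 | 30.74 | 0.43 | 0.47 | 0.00 | 0.00 | 0.02 | 37.55 | 3.7 | FSDP |
| 9 | 1 | 8 | 1 | 1 | 8 | 2 | save_all | 1 | 2 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 22.1 | 18.6 | 80.5 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 37.38 | 3.7 | FSDP |
| 10 | 1 | 4 | 1 | 4 | 2 | 4 | save_all | 1 | 4 | 40 | 10.0 | 19.9 | 10.0 | 1.9 | 20.2 | 18.6 | 80.7 | 204.6 | 0 | 1.39 | 0.00 | 4.50 | 30.74 | 0.32 | 0.41 | 0.00 | 0.00 | 0.02 | 37.38 | 3.7 | FSDP |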

Summary

  • Best config: TP=1 DP=4 PP=1 EP=1 FSDP=32 CP=1 remat=save_out_proj

  • Best step time: 4.402s, MFU: 31.5%

  • MFU range: 3.4% - 31.5%

  • Step time range: 4.402s - 40.406s