Theoretical Analysis Results
ALModel Parallelism Strategy Analysis
Model
| Property | Value |
|---|---|
| Parameters | 17.43B (expert: 16.11B, non-expert: 1.33B) |
| Layers | 21 total (16 GLA + 5 MLA, 20 MoE + 1 dense MLP) |
| Experts | 256 x FFN 512, top_k=8 |
| Attention | 16h x 128d, EMB_DIM=2048 |
Training
| Property | Value |
|---|---|
| Global batch size | 5120 |
| Sequence length | 4096 |
| Forward FLOPs/sample | 13.33 TFLOPs |
| Backward FLOPs/sample | 26.63 TFLOPs |
| Training FLOPs/sample | 39.96 TFLOPs |
| Training FLOPs/step | 204.57 PFLOPs |
| Remat overhead (full) | +33.4% |
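
The totals in the Training table follow from simple arithmetic. The sketch below reproduces them under one assumption the report does not state: that full rematerialization re-runs the forward pass exactly once, which happens to match the reported +33.4%. Variable names are illustrative.

```python
# Per-sample FLOPs from the Training table (TFLOPs).
FWD_TFLOPS = 13.33
BWD_TFLOPS = 26.63
GLOBAL_BATCH = 5120

train_tflops = FWD_TFLOPS + BWD_TFLOPS            # forward + backward
step_pflops = train_tflops * GLOBAL_BATCH / 1000  # TFLOPs -> PFLOPs

# Assumption: "full" remat recomputes one extra forward pass per sample.
remat_overhead = FWD_TFLOPS / train_tflops

print(f"training FLOPs/sample: {train_tflops:.2f} TFLOPs")
print(f"training FLOPs/step:   {step_pflops:.2f} PFLOPs")
print(f"full remat overhead:   +{remat_overhead:.1%}")
```

The per-step figure computed from the rounded per-sample values differs from the table's 204.57 PFLOPs in the second decimal place, presumably because the report multiplies unrounded values.
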
Hardware
| Property | Value |
|---|---|
| Devices | 128 x 1153.5 TFLOPs/device = 147648 TFLOPs total |
| HBM | 96 GiB/device |
| ICI bandwidth | 1200 GB/s bidirectional (nominal) |
| Measured BW (GB/s) | AR={2: 23.1, 4: 46.3, 8: 80.1}, AG={2: 34.3, 4: 89.9, 8: 186.3}, RS={2: 46.0, 4: 92.5, 8: 185.3}, A2A= |
| PP ppermute | 600 GB/s (overlapped) |
| FSDP overlap | 85% |
| XLA reserve | 30% |
| Total memory (no parallelism) | W=64.9GB + O=129.9GB + G=64.9GB = 259.8GB |
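
The unsharded memory figures can be reproduced under assumptions the report does not spell out: 4 bytes per parameter for weights and for gradients, 8 bytes per parameter for the two Adam moments, and sizes labeled "GB" that are in fact GiB (2^30 bytes). A minimal sketch:

```python
PARAMS = 17.43e9  # total parameters from the Model table
GIB = 2**30

# Assumptions: fp32 weights and gradients (4 bytes each),
# Adam mu + nu kept in fp32 (8 bytes together).
weights_gib = PARAMS * 4 / GIB
optimizer_gib = PARAMS * 8 / GIB
grads_gib = PARAMS * 4 / GIB
total_gib = weights_gib + optimizer_gib + grads_gib

print(f"W={weights_gib:.1f} O={optimizer_gib:.1f} "
      f"G={grads_gib:.1f} total={total_gib:.1f}")
```

The result lands within 0.1 of the table's total; the remaining gap is consistent with the parameter count being rounded to two decimals.
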
Configs with MFU >= 20% (sorted by step time)
REMAT_LABELS = {
    "save_all": "save_all",
    "minimal_with_context": "min+ctx",
    "minimal": "minimal",
    "save_dot_with_context_except_mlp": "dot-mlp+ctx",
    "save_dot_except_mlpwi": "dot-mlpwi",
    "save_dot_except_mlp": "dot-mlp",
    "save_qkv_proj": "qkv_proj",
    "save_out_proj": "out_proj",
    "full": "full",
}

| # | TP | DP | PP | EP | FSDP | CP | Remat | PDB | MB | GA | W(GB) | O(GB) | G(GB) | FBuf | Act(GB) | Rsv | Tot(GB) | Trn(PF) | Rmt% | Comp | TP(s) | EP(s) | FSDP+ | DP(s) | CP(s) | PP+(s) | Bub | Opt | Step(s) | MFU% | Bottleneck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 8 | 1 | 1 | 16 | 1 | full | 10 | 10 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.04 | 0.00 | 0.00 | 0.00 | 0.01 | 5.89 | 23.5 | FSDP |
| 2 | 1 | 4 | 1 | 1 | 16 | 2 | full | 10 | 20 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.03 | 0.13 | 0.00 | 0.00 | 0.01 | 6.01 | 23.0 | FSDP |
| 3 | 1 | 4 | 1 | 1 | 32 | 1 | out_proj | 10 | 10 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.02 | 0.00 | 0.00 | 0.00 | 0.00 | 6.04 | 22.9 | FSDP |
| 4 | 1 | 2 | 1 | 1 | 64 | 1 | qkv_proj | 10 | 10 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 61.8 | 18.5 | 92.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 4.32 | 0.01 | 0.00 | 0.00 | 0.00 | 0.00 | 6.11 | 22.7 | FSDP |
| 5 | 1 | 1 | 1 | 1 | 128 | 1 | dot-mlp | 10 | 10 | 4 | 0.5 | 1.0 | 0.5 | 8.3 | 65.1 | 18.9 | 94.3 | 204.6 | 27 | 1.76 | 0.00 | 0.00 | 4.38 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 6.14 | 22.6 | FSDP |
| 6 | 1 | 2 | 1 | 1 | 32 | 2 | out_proj | 10 | 20 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.01 | 0.13 | 0.00 | 0.00 | 0.00 | 6.16 | 22.5 | FSDP |
| 7 | 1 | 2 | 1 | 1 | 16 | 4 | full | 10 | 40 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.02 | 0.34 | 0.00 | 0.00 | 0.01 | 6.21 | 22.3 | FSDP |
| 8 | 1 | 1 | 1 | 1 | 64 | 2 | qkv_proj | 10 | 20 | 4 | 1.0 | 2.0 | 1.0 | 8.3 | 61.8 | 18.5 | 92.7 | 204.6 | 29 | 1.78 | 0.00 | 0.00 | 4.32 | 0.00 | 0.13 | 0.00 | 0.00 | 0.00 | 6.24 | 22.2 | FSDP |
| 9 | 1 | 1 | 1 | 1 | 32 | 4 | out_proj | 10 | 40 | 4 | 2.0 | 4.1 | 2.0 | 8.1 | 54.5 | 17.7 | 88.5 | 204.6 | 32 | 1.82 | 0.00 | 0.00 | 4.19 | 0.00 | 0.34 | 0.00 | 0.00 | 0.00 | 6.36 | 21.8 | FSDP |
| 10 | 1 | 1 | 1 | 1 | 16 | 8 | full | 10 | 80 | 4 | 4.1 | 8.1 | 4.1 | 7.9 | 51.2 | 18.8 | 94.2 | 204.6 | 33 | 1.85 | 0.00 | 0.00 | 3.99 | 0.00 | 0.76 | 0.00 | 0.00 | 0.01 | 6.60 | 21.0 | FSDP |
Column Legend
| Column | Description |
|---|---|
| TP / DP / PP / EP / FSDP / CP | Tensor / Data / Pipeline / Expert / Fully Sharded Data / Context parallelism degree |
| Remat | Rematerialization (activation checkpointing) policy |
| PDB | per_device_batch_size, the MaxText config value for batch size per device |
| MB | Actual physical per-device batch = PDB × TP × CP (batch not sharded by tensor/context axes) |
| GA | Gradient accumulation steps |
| W(GB) | Model weights per device (GB) |
| O(GB) | Optimizer states per device (GB), including Adam mu + nu |
| G(GB) | Gradients per device (GB) |
| FBuf | FSDP all-gather prefetch buffer (GB), peak = 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Act(GB) | Activation memory per device (GB), depends on micro batch and remat policy |
| Rsv | XLA overhead reserve (GB), 30% of modeled memory for comm buffers & HLO temps |
| Tot(GB) | Total HBM usage per device (GB) = W + O + G + FBuf + Act + Rsv |
| Trn(PF) | Training FLOPs per step (PetaFLOPs), useful fwd+bwd without remat overhead |
| Rmt% | Remat overhead percentage, extra compute from activation recomputation |
| Comp | Compute time (s), actual FLOPs / (num_devices x peak TFLOPs/device) |
| TP(s) | TP all-reduce communication time (s), non-overlappable |
| EP(s) | EP all-to-all communication time (s), non-overlappable |
| FSDP+ | FSDP exposed communication time (s), partially overlapped with compute (efficiency=85%) |
| DP(s) | DP gradient all-reduce time (s), non-overlappable |
| CP(s) | CP KV all-gather communication time (s), non-overlappable |
| PP+(s) | PP ppermute communication time (s), fully overlapped by XLA scheduler |
| Bub | PP bubble idle time (s), derived from the bubble fraction (PP-1)/(num_repeats x GA + PP-1) |
| Opt | Optimizer step time (s), memory-bound AdamW (28 bytes/param, HBM_BW=3690 GB/s) |
| Step(s) | Total step time (s) = Comp + all comm + Bub + Opt |
| MFU% | Model FLOPs Utilization (%), useful TFLOPs / (step_time x num_devices x peak TFLOPs) |
| Bottleneck | The dominant time component limiting throughput |
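
The Comp, Bub, and MFU formulas in the legend can be sketched as small helpers. The function names are illustrative and not taken from the analysis tool; the per-device peak is derived from the reported cluster total, and the PP=4 bubble case is a hypothetical input, since every listed configuration uses PP=1.

```python
def compute_time_s(actual_pflops, num_devices, peak_tflops_per_device):
    """Comp: actual FLOPs (useful + remat) over aggregate peak throughput."""
    return actual_pflops * 1e3 / (num_devices * peak_tflops_per_device)


def bubble_fraction(pp, num_repeats, ga):
    """Bub formula from the legend; a dimensionless idle fraction."""
    return (pp - 1) / (num_repeats * ga + pp - 1)


def mfu(useful_pflops, step_time_s, num_devices, peak_tflops_per_device):
    """MFU: useful FLOPs over what the cluster could do at peak in one step."""
    return useful_pflops * 1e3 / (
        step_time_s * num_devices * peak_tflops_per_device
    )


NUM_DEVICES = 128
PEAK_TFLOPS = 147648 / NUM_DEVICES  # per-device peak implied by the total

# Values reported for the best configuration (DP=4, FSDP=32, save_out_proj).
comp = compute_time_s(269.1, NUM_DEVICES, PEAK_TFLOPS)
best_mfu = mfu(204.57, 4.402, NUM_DEVICES, PEAK_TFLOPS)

print(f"Comp = {comp:.3f} s, MFU = {best_mfu:.1%}")
print(bubble_fraction(pp=1, num_repeats=1, ga=4))  # no pipeline, no bubble
print(bubble_fraction(pp=4, num_repeats=1, ga=4))  # hypothetical PP=4 case
```
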
#1: TP=1 DP=4 PP=1 EP=1 FSDP=32 CP=1 remat=save_out_proj
Memory per device
| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 2.03 | expert: 1.88 + non-expert: 0.15 |
| Optimizer | 4.06 | |
| Gradients | 2.03 | |
| FSDP buffer | 8.14 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=10 x inflight=1, remat=save_out_proj |
| XLA reserve | 21.24 | 30% of modeled |
| Total | 92.02 | / 96 GiB HBM (96% used) |
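
As a check on the memory table above, the sketch below sums the modeled components, applies the 30% XLA reserve, and evaluates the FSDP buffer formula. The per-layer weight size is not given in the report; it is back-solved here from the 8.14 GB buffer at FSDP=32 and should be read as an inferred value. Helper names are illustrative.

```python
def fsdp_buffer_gb(full_layer_weight_gb, fsdp):
    """FBuf: two layers of gathered weights, minus the locally held shard."""
    return 2 * full_layer_weight_gb * (fsdp - 1) / fsdp


def total_memory_gb(components_gb, reserve_frac=0.30):
    """Tot: modeled memory plus the XLA reserve taken as a fraction of it."""
    modeled = sum(components_gb.values())
    reserve = reserve_frac * modeled
    return modeled, reserve, modeled + reserve


components = {
    "weights": 2.03,
    "optimizer": 4.06,
    "gradients": 2.03,
    "fsdp_buffer": 8.14,
    "activations": 54.53,
}
modeled, reserve, total = total_memory_gb(components)
print(f"modeled={modeled:.2f} reserve={reserve:.2f} total={total:.2f}")

# Inferred, not reported: layer size that yields FBuf=8.14 at FSDP=32.
layer_gb = 8.14 * 32 / 31 / 2
for fsdp in (32, 64, 128):
    print(fsdp, round(fsdp_buffer_gb(layer_gb, fsdp), 2))
```

With the inferred layer size, the formula stays within 0.01 GB of the buffer sizes reported for the FSDP=64 and FSDP=128 configurations.
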
Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40
FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%
Communication
| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 754.9 | 4.06 total, 2.51 exposed | overlap=85% |
| DP all-reduce | 3.0 | 0.07 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 14.2 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
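
The "exposed" FSDP time in the table is not 15% of the total. The reported numbers are consistent with a different reading, used here as an assumption: up to 85% of the compute time can hide FSDP traffic, and whatever remains is exposed. The optimizer row can be checked the same way, assuming the volume is quoted in GiB and the HBM bandwidth in decimal GB/s. Function names are illustrative.

```python
def fsdp_exposed_s(fsdp_total_s, compute_s, overlap=0.85):
    """Assumed model: overlap x compute time is available to hide FSDP."""
    return max(0.0, fsdp_total_s - overlap * compute_s)


def adamw_step(params_per_device, bytes_per_param=28, hbm_gb_per_s=3690):
    """Memory-bound optimizer step: returns (volume in GiB, time in s)."""
    bytes_moved = params_per_device * bytes_per_param
    return bytes_moved / 2**30, bytes_moved / (hbm_gb_per_s * 1e9)


# DP=4, FSDP=32 configuration: weights are sharded 32 ways.
exposed = fsdp_exposed_s(fsdp_total_s=4.06, compute_s=1.823)
volume_gib, opt_time_s = adamw_step(17.43e9 / 32)

print(f"FSDP exposed: {exposed:.2f} s")
print(f"AdamW: {volume_gib:.1f} GiB moved in {opt_time_s:.3f} s")
```

The same overlap model gives about 2.64 s for the FSDP=128 configuration (4.16 s total, 1.784 s compute), matching its table.
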
Performance
Step time: 4.402s (compute=1.823 + comm=2.576 + optimizer=0.004)
MFU: 31.5%
Bottleneck: FSDP
#2: TP=1 DP=1 PP=1 EP=1 FSDP=128 CP=1 remat=save_qkv_proj
Memory per device
| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 0.51 | expert: 0.47 + non-expert: 0.04 |
| Optimizer | 1.01 | |
| Gradients | 0.51 | |
| FSDP buffer | 8.33 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 61.80 | micro_batch=10 x inflight=1, remat=save_qkv_proj |
| XLA reserve | 21.65 | 30% of modeled |
| Total | 93.81 | / 96 GiB HBM (98% used) |
Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40
FLOPs: useful=204.6 PFLOPs, actual=263.4 PFLOPs, remat=+29%
Communication
| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 773.2 | 4.16 total, 2.64 exposed | overlap=85% |
| DP all-reduce | 0.0 | 0.00 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 3.6 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
Performance
Step time: 4.426s (compute=1.784 + comm=2.641 + optimizer=0.001)
MFU: 31.3%
Bottleneck: FSDP
#3: TP=1 DP=2 PP=1 EP=1 FSDP=64 CP=1 remat=save_out_proj
Memory per device
| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 1.01 | expert: 0.94 + non-expert: 0.08 |
| Optimizer | 2.03 | |
| Gradients | 1.01 | |
| FSDP buffer | 8.27 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=10 x inflight=1, remat=save_out_proj |
| XLA reserve | 20.06 | 30% of modeled |
| Total | 86.91 | / 96 GiB HBM (91% used) |
Batch: per_device_batch=10, micro_batch=10 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=40
FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%
Communication
| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 767.1 | 4.12 total, 2.58 exposed | overlap=85% |
| DP all-reduce | 1.0 | 0.04 | non-overlappable |
| CP all-gather (KV) | 0.0 | 0.00 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 7.1 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
Performance
Step time: 4.444s (compute=1.823 + comm=2.620 + optimizer=0.002)
MFU: 31.2%
Bottleneck: FSDP
#4: TP=1 DP=1 PP=1 EP=1 FSDP=64 CP=2 remat=save_out_proj
Memory per device
| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 1.01 | expert: 0.94 + non-expert: 0.08 |
| Optimizer | 2.03 | |
| Gradients | 1.01 | |
| FSDP buffer | 8.27 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=20 x inflight=1, remat=save_out_proj |
| XLA reserve | 20.06 | 30% of modeled |
| Total | 86.91 | / 96 GiB HBM (91% used) |
Batch: per_device_batch=10, micro_batch=20 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=80
FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%
Communication
| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 767.1 | 4.12 total, 2.58 exposed | overlap=85% |
| DP all-reduce | 0.0 | 0.00 | non-overlappable |
| CP all-gather (KV) | 12.5 | 0.32 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 7.1 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
Performance
Step time: 4.718s (compute=1.823 + comm=2.894 + optimizer=0.002)
MFU: 29.4%
Bottleneck: FSDP
#5: TP=1 DP=2 PP=1 EP=1 FSDP=32 CP=2 remat=save_out_proj
Memory per device
| Component | Size (GB) | Detail |
|---|---|---|
| Weights | 2.03 | expert: 1.88 + non-expert: 0.15 |
| Optimizer | 4.06 | |
| Gradients | 2.03 | |
| FSDP buffer | 8.14 | all-gather peak: 2 layers x full_layer_weight x (FSDP-1)/FSDP |
| Activations | 54.53 | micro_batch=20 x inflight=1, remat=save_out_proj |
| XLA reserve | 21.24 | 30% of modeled |
| Total | 92.02 | / 96 GiB HBM (96% used) |
Batch: per_device_batch=10, micro_batch=20 (PDB×TP×CP), GA=4, batch_par=128, effective_batch/dev=80
FLOPs: useful=204.6 PFLOPs, actual=269.1 PFLOPs, remat=+32%
Communication
| Type | Volume (GB) | Time (s) | Note |
|---|---|---|---|
| TP all-reduce | 0.0 | 0.00 | non-overlappable |
| EP all-to-all | 0.0 | 0.00 | non-overlappable |
| FSDP gather+RS | 754.9 | 4.06 total, 2.51 exposed | overlap=85% |
| DP all-reduce | 2.0 | 0.09 | non-overlappable |
| CP all-gather (KV) | 12.5 | 0.32 | non-overlappable |
| PP ppermute (ICI) | 0.0 | 0.00 | fully overlapped |
| Optimizer (AdamW) | 14.2 | 0.00 | memory-bound, 28B/param, HBM=3690GB/s |
Performance
Step time: 4.743s (compute=1.823 + comm=2.916 + optimizer=0.004)
MFU: 29.2%
Bottleneck: FSDP
Worst 10 configurations (slowest)
| # | TP | DP | PP | EP | FSDP | CP | Remat | PDB | MB | GA | W(GB) | O(GB) | G(GB) | FBuf | Act(GB) | Rsv | Tot(GB) | Trn(PF) | Rmt% | Comp | TP(s) | EP(s) | FSDP+ | DP(s) | CP(s) | PP+(s) | Bub | Opt | Step(s) | MFU% | Bottleneck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 2 | 1 | 2 | 4 | 8 | save_all | 1 | 8 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 19.3 | 17.5 | 75.8 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.38 | 0.47 | 0.00 | 0.00 | 0.02 | 40.41 | 3.4 | FSDP |
| 2 | 1 | 4 | 1 | 2 | 4 | 4 | save_all | 1 | 4 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 20.2 | 17.8 | 77.0 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.28 | 0.41 | 0.00 | 0.00 | 0.02 | 40.25 | 3.4 | FSDP |
| 3 | 1 | 8 | 1 | 2 | 4 | 2 | save_all | 1 | 2 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 22.1 | 18.3 | 79.4 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.19 | 0.32 | 0.00 | 0.00 | 0.02 | 40.07 | 3.5 | FSDP |
| 4 | 1 | 16 | 1 | 2 | 4 | 1 | save_all | 1 | 1 | 40 | 8.7 | 17.5 | 8.7 | 4.0 | 25.8 | 19.5 | 84.3 | 204.6 | 0 | 1.39 | 0.00 | 4.68 | 33.47 | 0.20 | 0.00 | 0.00 | 0.00 | 0.02 | 39.76 | 3.5 | FSDP |
| 5 | 1 | 8 | 1 | 1 | 16 | 1 | save_all | 1 | 1 | 40 | 4.1 | 8.1 | 4.1 | 7.9 | 25.8 | 15.0 | 64.9 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 38.11 | 0.09 | 0.00 | 0.00 | 0.00 | 0.01 | 39.59 | 3.5 | FSDP |
| 6 | 1 | 2 | 1 | 1 | 8 | 8 | save_all | 1 | 8 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 19.3 | 17.7 | 76.8 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.35 | 0.47 | 0.00 | 0.00 | 0.02 | 37.71 | 3.7 | FSDP |
| 7 | 1 | 4 | 1 | 1 | 8 | 4 | save_all | 1 | 4 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 20.2 | 18.0 | 78.1 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.26 | 0.41 | 0.00 | 0.00 | 0.02 | 37.56 | 3.7 | FSDP |
| 8 | 1 | 2 | 1 | 4 | 2 | 8 | save_all | 1 | 8 | 40 | 10.0 | 19.9 | 10.0 | 1.9 | 19.3 | 18.3 | 79.4 | 204.6 | 0 | 1.39 | 0.00 | 4.50 | 30.74 | 0.43 | 0.47 | 0.00 | 0.00 | 0.02 | 37.55 | 3.7 | FSDP |
| 9 | 1 | 8 | 1 | 1 | 8 | 2 | save_all | 1 | 2 | 40 | 8.1 | 16.2 | 8.1 | 7.3 | 22.1 | 18.6 | 80.5 | 204.6 | 0 | 1.39 | 0.00 | 0.00 | 35.49 | 0.18 | 0.32 | 0.00 | 0.00 | 0.02 | 37.38 | 3.7 | FSDP |
| 10 | 1 | 4 | 1 | 4 | 2 | 4 | save_all | 1 | 4 | 40 | 10.0 | 19.9 | 10.0 | 1.9 | 20.2 | 18.6 | 80.7 | 204.6 | 0 | 1.39 | 0.00 | 4.50 | 30.74 | 0.32 | 0.41 | 0.00 | 0.00 | 0.02 | 37.38 | 3.7 | FSDP |
Summary
Best config: TP=1 DP=4 PP=1 EP=1 FSDP=32 CP=1 remat=save_out_proj
Best step time: 4.402s, MFU: 31.5%
MFU range: 3.4% - 31.5%
Step time range: 4.402s - 40.406s