
Detailed step-by-step notes for reproducing the setup, from training and scaling to inference (Grattafiori et al., 2024)

They pre-train the 405B model on 15.6T tokens with an 8K context window.

The data mix: roughly 50% general-knowledge tokens, 25% mathematical and reasoning tokens, 17% code tokens, and 8% multilingual tokens.

They also use data annealing (upsampling high-quality data at the end of pre-training) to improve quality (Blakeney et al., 2024).

They also run their own scaling-law experiments instead of relying on the Chinchilla constants.
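
As a back-of-the-envelope cross-check (my own arithmetic using the common $C \approx 6ND$ approximation, not the paper's FLOP accounting):

$$C \approx 6 \times (405 \times 10^{9}) \times (15.6 \times 10^{12}) \approx 3.8 \times 10^{25}\ \text{FLOPs},$$

which matches the $3.8 \times 10^{25}$ FLOPs the paper reports for the flagship run.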

Architecture-wise, nothing special: a standard dense Transformer with Grouped-Query Attention (GQA) and SwiGLU FFNs.

|                       | 8B       | 70B        | 405B     |
|-----------------------|----------|------------|----------|
| Layers                | 32       | 80         | 126      |
| Model Dimension       | 4,096    | 8,192      | 16,384   |
| FFN Dimension         | 14,336   | 28,672     | 53,248   |
| Attention Heads       | 32       | 64         | 128      |
| Key/Value Heads       | 8        | 8          | 8        |
| Peak Learning Rate    | 3 × 10⁻⁴ | 1.5 × 10⁻⁴ | 8 × 10⁻⁵ |
| Activation Function   | SwiGLU   | SwiGLU     | SwiGLU   |
| Vocabulary Size       | 128,000  | 128,000    | 128,000  |
| Positional Embeddings | RoPE (θ = 500,000) | RoPE (θ = 500,000) | RoPE (θ = 500,000) |
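
As a cross-check on the 405B column, a rough parameter count from the table's dimensions lands at about 405B. This is a sketch under my own assumptions (untied input/output embeddings, no biases, norm parameters ignored):

```python
# Rough parameter count for the 405B config from the table above.
d_model  = 16_384
d_ffn    = 53_248
n_layers = 126
n_heads  = 128
n_kv     = 8            # grouped-query attention: K/V shared across head groups
vocab    = 128_000
head_dim = d_model // n_heads   # 128

attn = 2 * d_model * d_model             # W_q, W_o (full size)
attn += 2 * d_model * (n_kv * head_dim)  # W_k, W_v (reduced by GQA)
ffn  = 3 * d_model * d_ffn               # SwiGLU: gate, up, down projections

total = n_layers * (attn + ffn) + 2 * vocab * d_model  # + embeddings + LM head
print(f"{total / 1e9:.1f}B")             # ≈ 405.8B
```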

Training config:

| GPUs   | TP | CP | PP | DP  | Seq. Len. | Batch size/DP | Tokens/Batch | TFLOPs/GPU | BF16 MFU |
|--------|----|----|----|-----|-----------|---------------|--------------|------------|----------|
| 8,192  | 8  | 1  | 16 | 64  | 8,192     | 32            | 16M          | 430        | 43%      |
| 16,384 | 8  | 1  | 16 | 128 | 8,192     | 16            | 16M          | 400        | 41%      |
| 16,384 | 8  | 16 | 16 | 8   | 131,072   | 16            | 16M          | 380        | 38%      |
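
Two quick consistency checks on this table (my own arithmetic): the GPU count should be the product of the parallelism degrees, and the tokens per global batch should be DP × (batch size per DP rank) × sequence length:

```python
# Each row: (GPUs, TP, CP, PP, DP, seq_len, batch_size_per_dp_rank)
rows = [
    (8_192,  8,  1, 16,  64,   8_192, 32),
    (16_384, 8,  1, 16, 128,   8_192, 16),
    (16_384, 8, 16, 16,   8, 131_072, 16),
]
for gpus, tp, cp, pp, dp, seq_len, bsz in rows:
    assert tp * cp * pp * dp == gpus   # parallelism degrees multiply out to the GPU count
    print(dp * bsz * seq_len)          # 16,777,216 tokens per batch in every row ("16M")
```
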
  • 16K H100 cluster (a production cluster rather than a research cluster)
    • 8 pods of 3,072 GPUs each, with roughly a 1:7 oversubscription ratio between pods (i.e., ~7× lower cross-pod bandwidth)
  • took around 54 days for pre-training
  • Theoretical peak for the H100 is ≈1,979 TFLOPs BF16 with sparsity (≈989 TFLOPs dense, which is the baseline the MFU figures are measured against)
  • Training days can be estimated as $\text{training days} = \frac{\text{total tokens}}{\text{throughput (tokens/sec)} \times 86{,}400}$
  • Model FLOPs Utilization (MFU) is usually $\text{MFU} = \frac{\text{global batch size} \times \text{model FLOPs}}{\text{training step time} \times \text{nGPUs} \times \text{peak GPU FLOPs}}$
    • 38–43% utilization here (a worked example with the 405B numbers follows this list)
  • Schedule:
    • linear warmup of 8000 steps
    • peak LR of $8 \times 10^{-5}$ with a cosine schedule decaying to $8 \times 10^{-7}$ at 1.2M steps (a minimal sketch of this schedule follows the list)
      • initial batch size of 4M tokens with a sequence length of 4,096
      • doubled to a batch size of 8M tokens with 8,192-token sequences after pre-training on 252M tokens
      • doubled again to 16M tokens (still 8,192-token sequences) after pre-training on 2.87T tokens
  • Network configuration:
    • a custom variant of NCCL (NCCLX)
    • RDMA over Converged Ethernet (RoCE) fabric based on Arista 7800 and Minipack2 Open Compute Project (OCP) rack switches
    • both RoCE and InfiniBand clusters are used; the 405B model is trained on the RoCE fabric
    • Topology:
      • a three-layer Clos network
  • Training recipe: 4D parallelism with FSDP
    • tensor parallelism: splits individual weight tensors into chunks across multiple devices
    • pipeline parallelism: partitions the model vertically into stages of consecutive layers, so different devices can process different stages of the pipeline in parallel
    • context parallelism: divides the input context into segments, reducing the memory bottleneck for long-sequence inputs
    • FSDP: shards the model weights, optimizer state, and gradients while implementing data parallelism (different GPUs process different data and synchronize every training step)
      • The parallelism order is network-aware (TP on the innermost, highest-bandwidth links; DP on the outermost), but the collectives are still essentially all-gathers (a toy rank-mapping sketch follows this list)
      • FSDP runs in ZeRO-2 style rather than ZeRO-3: weight shards are gathered for the forward pass and kept materialized instead of being re-gathered in the backward pass
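
As referenced above, a worked example of the training-time and MFU formulas with the 405B numbers. The $6ND$ FLOPs-per-token approximation and the ≈989 TFLOPs dense H100 peak are my assumptions; the per-GPU throughput comes from the table:

```python
# Back-of-the-envelope training time and MFU for the 405B run.
n_params = 405e9
n_tokens = 15.6e12
flops_per_token = 6 * n_params                  # ~2.4e12 FLOPs/token (6ND rule)

n_gpus = 16_384
achieved_per_gpu = 430e12                       # FLOP/s, first row of the table
tokens_per_sec = n_gpus * achieved_per_gpu / flops_per_token
print(f"~{n_tokens / (tokens_per_sec * 86_400):.0f} days")   # ≈ 62 days

# MFU = achieved / peak; the table's percentages follow from the *dense*
# BF16 peak (≈989 TFLOPs), not the 1,979 TFLOPs sparsity figure.
peak_dense = 989e12
for achieved in (430e12, 400e12, 380e12):
    print(f"MFU ≈ {achieved / peak_dense:.0%}") # ≈43%, ≈40%, ≈38% (rounding vs. table)
```

The ≈62-day estimate is in the same ballpark as the ~54 days noted above; the gap is expected for a rough estimate that ignores the staged cluster sizes and throughputs.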
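
A minimal sketch of the LR schedule above (linear warmup over 8,000 steps to a peak of 8e-5, then cosine decay to 8e-7 at 1.2M steps); the exact implementation details are my assumption, only the shape and endpoints come from the notes:

```python
import math

def lr_at(step: int,
          peak_lr: float = 8e-5,
          min_lr: float = 8e-7,
          warmup_steps: int = 8_000,
          total_steps: int = 1_200_000) -> float:
    """Linear warmup followed by cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at(0), lr_at(8_000), lr_at(1_200_000))   # 0.0  8e-05  8e-07
```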
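
Finally, the toy rank-mapping sketch referenced above: how one of the 16,384-GPU configurations decomposes a global rank into (TP, CP, PP, DP) coordinates, with TP varying fastest so it stays on the highest-bandwidth links. This is my illustration of the idea, not Meta's actual mapping:

```python
# Decompose a global GPU rank into 4D parallel coordinates (TP innermost, DP outermost).
TP, CP, PP, DP = 8, 1, 16, 128          # second row of the training-config table
assert TP * CP * PP * DP == 16_384

def coords(rank: int) -> dict:
    tp, rank = rank % TP, rank // TP    # fastest-varying: tensor parallel (intra-server)
    cp, rank = rank % CP, rank // CP
    pp, rank = rank % PP, rank // PP
    dp = rank % DP                      # slowest-varying: data parallel (cross-pod)
    return {"tp": tp, "cp": cp, "pp": pp, "dp": dp}

print(coords(0))    # {'tp': 0, 'cp': 0, 'pp': 0, 'dp': 0}
print(coords(7))    # last rank of the first TP group: {'tp': 7, ...}
print(coords(8))    # rolls over into the next PP stage: {'tp': 0, 'cp': 0, 'pp': 1, 'dp': 0}
```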

Bibliography

  • Blakeney, C., Paul, M., Larsen, B. W., Owen, S., & Frankle, J. (2024). Does your data spark joy? Performance gains from domain upsampling at the end of training. arXiv preprint arXiv:2406.03476 [arxiv]
  • Grattafiori, A., Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., Mathur, A., Schelten, A., Vaughan, A., Yang, A., Fan, A., Goyal, A., Hartshorn, A., Yang, A., Mitra, A., Sravankumar, A., Korenev, A., Hinsvark, A., … Ma, Z. (2024). The Llama 3 Herd of Models. arXiv preprint arXiv:2407.21783 [arxiv]