Memory Management for Optimizing LLMs

This post is inspired by EleutherAI's Transformer Math 101 and Jiayi Pan's VRAM Estimation notes. Originally drafted after discussion with Jiayi Pan and revised in March 2026 with improved visualizations.

Training a large language model means fitting everything the GPU needs — weights, optimizer buffers, gradients, and intermediate computations — into a fixed amount of VRAM. Understanding what each component costs is a prerequisite for reasoning about OOM errors. This section walks through the four main consumers of GPU memory during a training step.

What Lives on the GPU

Figure 1. Adjust model size, optimizer, and precision to see how GPU memory is consumed. Green = fits, red = OOM. The four cards break down each component's cost.

Let \(P\) denote the number of model parameters. During mixed-precision training (the standard practice for LLMs), the GPU holds four categories of data:

Model Parameters

The weights used in the forward pass. In mixed-precision training, the forward and backward passes run in half precision (fp16 or bf16), so the live weights consume:

\[\text{Model memory} = 2P \text{ bytes}\]

(2 bytes per parameter for bf16/fp16.)

Optimizer States

The optimizer maintains its own buffers that persist across training steps. For AdamW (the standard choice for LLM training), the update rule at each step \(t\) is:

\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) \, g_t\] \[v_t = \beta_2 v_{t-1} + (1 - \beta_2) \, g_t^2\] \[\theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, \theta_{t-1} \right)\]

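where \(\hat{m}_t = m_t / (1 - \beta_1^t)\) and \(\hat{v}_t = v_t / (1 - \beta_2^t)\) are the bias-corrected estimates, \(\eta\) is the learning rate, and \(\lambda\) is the weight-decay coefficient. The bias-corrected values are recomputed each step, but three quantities must be stored per parameter: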

  • fp32 master copy of \(\theta\) — 4 bytes/param. The optimizer updates weights in fp32 for numerical stability, then casts back to bf16 for the next forward pass.
  • First moment \(m_t\) — 4 bytes/param. The exponential moving average of gradients (momentum).
  • Second moment \(v_t\) — 4 bytes/param. The exponential moving average of squared gradients (variance).
\[\text{Optimizer memory} = \underbrace{4P}_{\theta^{\text{fp32}}} + \underbrace{4P}_{m} + \underbrace{4P}_{v} = 12P \text{ bytes}\]

This is typically the single largest memory consumer for model-related storage. For a 7B parameter model: \(12 \times 7 \times 10^9 = 84\) GB just for optimizer states.
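As a minimal sketch (pure Python on lists of floats, not how any framework actually stores or fuses these buffers), one AdamW step looks like this. The lists `theta`, `m`, and `v` are the three persistent per-parameter buffers counted above; the hyperparameter defaults are common choices assumed for illustration, not values from this post.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """Apply one AdamW update in place.

    theta, m, v are the persistent per-parameter buffers (fp32 master
    weights, first moment, second moment); grad is this step's gradient;
    t is the 1-based step count used for bias correction.
    """
    for i in range(len(theta)):
        # Exponential moving averages of the gradient and squared gradient.
        m[i] = beta1 * m[i] + (1 - beta1) * grad[i]
        v[i] = beta2 * v[i] + (1 - beta2) * grad[i] ** 2
        # Bias-corrected estimates: transient, never stored.
        m_hat = m[i] / (1 - beta1 ** t)
        v_hat = v[i] / (1 - beta2 ** t)
        # Adaptive step plus decoupled weight decay on the previous weights.
        theta[i] -= lr * (m_hat / (math.sqrt(v_hat) + eps)
                          + weight_decay * theta[i])


theta, m, v = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0]
adamw_step(theta, [0.5, -0.1], m, v, t=1, lr=0.1, weight_decay=0.0)
print(theta)  # approximately [0.9, -1.9]: the first step moves each weight by about lr
```

Only `theta`, `m`, and `v` survive between calls, which is why the optimizer cost has exactly three 4-byte terms.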

Other optimizers use less: SGD with momentum costs \(8P\) bytes (fp32 copy + momentum), and 8-bit optimizers like those in bitsandbytes reduce the moments to 1 byte each, costing \(6P\) bytes.

Gradients

The gradient tensor for each parameter, computed during the backward pass. In mixed precision:

\[\text{Gradient memory} = 2P \text{ bytes}\]

(Gradients are stored in the same precision as the model — bf16/fp16.)

Activations

The intermediate outputs of each layer, saved during the forward pass so they can be reused in the backward pass. Unlike the previous three, activation memory scales with the data being processed — specifically with batch size, sequence length, and model depth.

For a transformer with \(n_{\text{layers}}\) layers, hidden size \(d\), attention heads \(a\), sequence length \(s\), and microbatch size \(b\) (per GPU), the activation memory without recomputation is approximately:

\[\text{Activation memory} \approx s \cdot b \cdot d \cdot n_{\text{layers}} \cdot \left(10 + \frac{24}{t} + \frac{5as}{dt}\right) \text{ bytes}\]

where \(t\) is the tensor parallelism degree (1 if no tensor parallelism). The three terms inside the parentheses correspond to three groups of saved tensors per layer, factored by whether they are split across \(t\) GPUs:

The \(10\) term — activations that are not split by tensor parallelism (full \(d\)-dimensional tensors):

Saved tensor Shape Bytes (bf16)
Layer norm input (before self-attention) \([s, b, d]\) \(2sbd\)
Layer norm input (before MLP) \([s, b, d]\) \(2sbd\)
Input to Q/K/V projections (output of the first layer norm) \([s, b, d]\) \(2sbd\)
Input to the MLP’s first linear (output of the second layer norm) \([s, b, d]\) \(2sbd\)
Dropout mask (after self-attention) \([s, b, d]\) \(sbd\)
Dropout mask (after MLP) \([s, b, d]\) \(sbd\)
Subtotal   \(10 \cdot sbd\)

The \(24/t\) term — activations split across \(t\) tensor-parallel GPUs (each GPU stores \(d/t\)):

Saved tensor Shape per GPU Bytes (bf16)
Q projection output \([s, b, d/t]\) \(2sbd/t\)
K projection output \([s, b, d/t]\) \(2sbd/t\)
V projection output \([s, b, d/t]\) \(2sbd/t\)
Attention output (softmax-weighted values, input to the output projection) \([s, b, d/t]\) \(2sbd/t\)
MLP first linear output (input to GeLU) \([s, b, 4d/t]\) \(8sbd/t\)
GeLU output \([s, b, 4d/t]\) \(8sbd/t\)
Subtotal   \(24 \cdot sbd/t\)

The \(5as/(dt)\) term — attention matrices that scale as \(O(s^2)\):

Saved tensor Shape per GPU Bytes
Softmax output (attention probabilities) \([b, a/t, s, s]\) \(2bas^2/t\)
Attention dropout mask \([b, a/t, s, s]\) \(bas^2/t\)
Dropout output (input to the attention-over-values matmul) \([b, a/t, s, s]\) \(2bas^2/t\)
Subtotal   \(5bas^2/t = sbd \cdot \frac{5as}{dt}\)

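This last group is why long sequences are expensive: at \(s = 8192\) and \(d = 4096\), with, for example, \(a = 32\) heads and \(t = 1\), the third term is \(5 \cdot 32 \cdot 8192 / 4096 = 320\), more than nine times the other two combined (\(10 + 24 = 34\)).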

With full activation recomputation (gradient checkpointing), only the input to each layer is saved — everything else is recomputed during the backward pass. This reduces activation memory to \(2 \cdot s \cdot b \cdot d \cdot n_{\text{layers}}\) bytes, at the cost of roughly doubling the forward-pass compute.
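The formula, with and without recomputation, can be turned into a small calculator. This sketch implements only the approximation above; the configuration in the usage lines (32 layers, 32 heads, \(d = 4096\)) is an illustrative assumption, not a specific model from this post.

```python
def activation_terms(s, d, a, t=1):
    """The three per-layer factors inside the parentheses of the formula."""
    return 10, 24 / t, 5 * a * s / (d * t)


def activation_bytes(s, b, d, n_layers, a, t=1, recompute=False):
    """Approximate saved-activation memory (bytes) for one microbatch."""
    if recompute:
        # Full recomputation: only each layer's bf16 input [s, b, d] is kept.
        return 2 * s * b * d * n_layers
    return s * b * d * n_layers * sum(activation_terms(s, d, a, t))


# Illustrative config (assumed): s=8192, b=1, d=4096, 32 layers, 32 heads, no TP.
print(activation_terms(8192, 4096, 32))                                # (10, 24.0, 320.0)
print(activation_bytes(8192, 1, 4096, 32, 32) / 1e9)                   # ~380 GB
print(activation_bytes(8192, 1, 4096, 32, 32, recompute=True) / 1e9)   # ~2.1 GB
```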

Putting It Together

For a single GPU with no parallelism, the total training memory is:

\[\text{Total} = \underbrace{2P}_{\text{params}} + \underbrace{12P}_{\text{AdamW}} + \underbrace{2P}_{\text{grads}} + \underbrace{\text{Activations}(b, s)}_{\text{scales with data}}\]

The first three terms sum to \(16P\) bytes — fixed regardless of batch size. For a 7B model, that’s ~112 GB before any data is processed. The fourth term is the only one you can control at runtime by changing how much data you feed per forward pass.
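The fixed part of this total, together with the optimizer alternatives listed earlier, fits in a few lines. The per-optimizer byte counts are the ones from this section; the dictionary keys are illustrative labels, not names from any library.

```python
# Bytes per parameter for each fixed component (mixed precision, single GPU).
PARAM_BYTES = 2   # bf16 live weights
GRAD_BYTES = 2    # bf16 gradients
OPTIMIZER_BYTES = {
    "adamw": 12,          # fp32 master + m + v
    "sgd_momentum": 8,    # fp32 master + momentum
    "adam_8bit": 6,       # fp32 master + 1-byte m + 1-byte v
}


def fixed_training_bytes(num_params, optimizer="adamw"):
    """Memory that does not depend on batch size or sequence length."""
    per_param = PARAM_BYTES + GRAD_BYTES + OPTIMIZER_BYTES[optimizer]
    return per_param * num_params


for name in OPTIMIZER_BYTES:
    print(name, fixed_training_bytes(7e9, name) / 1e9, "GB")
# prints 112.0, 84.0 and 70.0 GB respectively for a 7B model
```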

This is where memory optimization techniques come in.

Optimizing GPU Memory Usage

Every memory optimization technique trades something for memory savings; there is no free lunch. The summary table at the end of this section lists the main approaches, what they save, and what they cost.

Mixed Precision Training

What it does. Store each kind of GPU-resident data at the precision it needs. The key insight is that the four memory components from the previous section differ in numerical sensitivity: activations and gradients are stable in half precision, but optimizer updates are not. The standard recipe:

Stored on GPU Precision Memory Why
Model parameters bf16 (live) + fp32 (master) \(2P + 4P\) Forward/backward use the bf16 copy; optimizer updates the fp32 master
Activations bf16/fp16 \(2 \cdot sbd\) per tensor Half precision is sufficient for intermediate computations
Gradients bf16/fp16 \(2P\) Same precision as the forward pass
Optimizer states (\(m_t, v_t\)) fp32 \(8P\) Momentum and variance need fp32 for numerical stability

The flow each training step: cast fp32 master weights → bf16, run forward/backward in bf16 (producing bf16 activations and gradients), pass gradients to the optimizer which updates the fp32 master weights, then refresh the bf16 copy.

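The tradeoff. fp16 has a narrow dynamic range (max \(\approx 6.5 \times 10^4\)) and can cause loss spikes or divergence without careful loss scaling. bf16 (available on Ampere and newer GPUs) has the same exponent range as fp32, making it more stable, but it keeps only 7 mantissa bits versus fp16’s 10, so it is less precise. Either way, the optimizer states remain in fp32. Storing the live parameters and gradients in half precision saves \(4P\) bytes, but the fp32 master copy costs \(4P\) bytes, so compared with pure fp32 training (\(4P + 4P + 8P = 16P\)) model-state memory is unchanged at \(16P\). The memory win comes from activations, which are stored at half the size, and that is the term that grows with batch size and sequence length.

\[\text{Mixed precision savings: } \underbrace{4P}_{\text{params + grads halved}} - \underbrace{4P}_{\text{fp32 master}} = 0 \text{ bytes of model state; activation memory halved}\]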
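The range-versus-precision difference between fp16 and bf16 can be checked with the standard library alone. This sketch emulates fp16 with Python's `struct` half-precision format and emulates bf16 by rounding an fp32 bit pattern to its top 16 bits; it is an illustration, not how GPUs perform the cast.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE fp16 (half precision) and back."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # Too large for fp16: real hardware would produce +/-inf.
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round a float to bfloat16 (round-to-nearest-even) and back.

    bf16 is the top 16 bits of an fp32: same 8-bit exponent, 7 mantissa
    bits. NaN handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                     # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_fp16(70000.0), to_bf16(70000.0))   # inf 70144.0: fp16 overflows, bf16 rounds
step = 2.0 ** -10
print(to_fp16(1 + step), to_bf16(1 + step)) # 1.0009765625 1.0: bf16 loses the small step
```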

Gradient Checkpointing (Activation Recomputation)

What it does. Instead of saving all intermediate activations during the forward pass, save only the input to each transformer layer. During the backward pass, recompute each layer’s intermediates from the saved input before computing gradients.

The tradeoff. Activation memory drops from \(sbd \cdot n_{\text{layers}} \cdot (10 + 24/t + 5as/(dt))\) to just \(2sbd \cdot n_{\text{layers}}\), typically a 5-10x reduction. The cost is roughly doubling the forward-pass compute, since every layer’s forward pass runs twice (once during the forward pass, and again during the backward pass). Wall-clock training time increases by ~30-40% in practice: the backward pass costs roughly twice as much as the forward pass, so one extra forward pass adds only about a third to the total step time.

Figure 2. Adjust layers, sequence length, batch size, and hidden dim to see per-layer activation breakdown. Toggle gradient checkpointing to compare memory savings. Note how attention scores scale as O(s²).
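A toy version makes the trade concrete. In this sketch (an illustration, not any framework's implementation), each "layer" is the scalar map \(h \mapsto \tanh(w h)\) and the loss is the final output. The full version keeps every layer's input and pre-activation; the checkpointed version keeps only layer inputs and recomputes the pre-activation during the backward pass. Both produce the same gradients.

```python
import math


def layer_forward(w, h):
    """Toy layer: returns (pre-activation z, output a)."""
    z = w * h
    return z, math.tanh(z)


def layer_backward(w, h, z, grad_out):
    """Given dL/d(output), return (dL/dw, dL/dh) for one toy layer."""
    grad_z = grad_out * (1.0 - math.tanh(z) ** 2)
    return grad_z * h, grad_z * w


def grads_full(ws, x):
    """Standard backprop: keep (input, pre-activation) for every layer."""
    saved, h = [], x
    for w in ws:
        z, h_next = layer_forward(w, h)
        saved.append((h, z))
        h = h_next
    grads, grad = [0.0] * len(ws), 1.0   # loss = final output
    for i in reversed(range(len(ws))):
        h_in, z = saved[i]
        grads[i], grad = layer_backward(ws[i], h_in, z, grad)
    return grads, 2 * len(saved)          # floats held after the forward pass


def grads_checkpointed(ws, x):
    """Checkpointing: keep only each layer's input, recompute z in backward."""
    inputs, h = [], x
    for w in ws:
        inputs.append(h)
        _, h = layer_forward(w, h)        # intermediates are discarded
    grads, grad = [0.0] * len(ws), 1.0
    for i in reversed(range(len(ws))):
        z, _ = layer_forward(ws[i], inputs[i])   # the extra forward pass
        grads[i], grad = layer_backward(ws[i], inputs[i], z, grad)
    return grads, len(inputs)
```

In a real transformer layer the saved intermediates outnumber the layer input by the factor in the activation formula's parentheses, so the saving is far larger than the 2x in this toy.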

Microbatching and Gradient Accumulation

Consider a training step where one GPU receives \(N\) sequences of varying lengths \(\{l_1, l_2, \ldots, l_N\}\), totaling \(L = \sum_i l_i\) tokens. The GPU must compute forward and backward passes over all \(N\) sequences before a single optimizer step.

The naive approach processes everything at once: pack all \(N\) sequences into a single tensor, forward pass, backward pass, optimizer step. But this requires holding all activations for \(N\) sequences simultaneously — which can easily exceed GPU memory even when the fixed costs (parameters + optimizer + gradients) fit comfortably.

Gradient accumulation solves this: split the \(N\) sequences into \(K\) microbatches, process each sequentially, accumulate gradients, then step.

\[\nabla_\theta \mathcal{L} = \sum_{k=1}^{K} \nabla_\theta \mathcal{L}_k\]

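The final gradient is identical to the single-batch case, provided the microbatch losses are weighted consistently: if \(\mathcal{L}\) is the mean loss over all \(L\) tokens in the step, each \(\mathcal{L}_k\) must be its microbatch’s summed token loss divided by \(L\), not by the microbatch’s own token count. The optimizer then sees the same update. But memory usage is dramatically different: activations are created and destroyed per microbatch, so only one microbatch’s activations exist at a time, and peak activation memory is determined by the largest microbatch, not the total batch.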

Figure 3. Animation of gradient accumulation. Watch how activations appear during each forward pass and are consumed during backward, while gradients accumulate across microbatches. Only one microbatch's activations exist at a time.
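For a sum-reduced loss, the equivalence can be checked on a one-parameter linear model with squared error. The model, data, and helper names below are hypothetical; the point is that one persistent gradient buffer accumulates per-microbatch contributions and ends up equal to the full-batch gradient.

```python
import math


def grad_sq_loss(w, xs, ys):
    """dL/dw for L = sum_i (w * x_i - y_i)^2 over one (micro)batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def accumulated_grad(w, xs, ys, num_microbatches):
    """Process the batch in sequential slices, summing into one buffer."""
    size = -(-len(xs) // num_microbatches)    # ceiling division
    grad = 0.0                                # persistent gradient buffer
    for start in range(0, len(xs), size):
        # Only this slice's intermediates exist during this iteration.
        grad += grad_sq_loss(w, xs[start:start + size], ys[start:start + size])
    return grad


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8]
print(math.isclose(accumulated_grad(1.5, xs, ys, 3), grad_sq_loss(1.5, xs, ys)))  # True
```

With a mean-reduced loss, each microbatch's contribution would need the reweighting described above before being added to the buffer.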

A bin-packing algorithm (First-Fit Decreasing) groups the \(N\) sequences into microbatches, where each microbatch’s total token count must not exceed max_tokens_per_mb.

Example: 84 sequences averaging 5000 tokens each (420K total tokens) on one GPU.

max_tokens_per_mb Microbatches Seqs/MB Peak activation memory
65536 ~7 ~12 High — 12 sequences × padded length
16384 ~26 ~3 Low — 3 sequences × padded length
8192 ~52 ~1-2 Minimal — 1-2 sequences × padded length

With max_tokens_per_mb=65536, each microbatch holds ~12 sequences, and the forward pass must store activations for all 12 simultaneously. With 16384, only ~3 sequences are processed at once, cutting activation memory to roughly a quarter.

The padding tax. There is a subtlety that makes larger microbatches even worse: within a microbatch, sequences are padded to the length of the longest sequence in that group, because GPU tensor operations need rectangular shapes. Consider a microbatch with 3 sequences of lengths [3000, 5000, 8000]. All three are padded to 8000:

\[\text{Effective tokens} = 3 \times 8000 = 24000 \quad \text{(vs actual 16000)}\]

That’s 8000 padding tokens: 50% overhead on top of the 16000 real tokens (one third of the padded tensor), all wasted compute and memory. Larger microbatches are more likely to contain outlier-length sequences, making the padding problem worse. Smaller max_tokens_per_mb means fewer sequences per group, so the longest sequence in each group is closer to the average, which means less padding waste and lower peak memory.

The tradeoff. Smaller microbatches reduce peak activation memory and padding waste, but increase the number of forward/backward passes per step — adding kernel launch overhead and reducing GPU utilization. Extremely small microbatches (1-2 sequences) also underutilize the GPU’s parallel compute units.

Figure 4. Drag the slider to change max_tokens_per_mb and see how sequences are packed into microbatches. Purple = actual tokens, pink = padding waste. The red-outlined microbatch determines peak activation memory.
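A minimal First-Fit Decreasing sketch, interpreting max_tokens_per_mb as a cap on real (unpadded) tokens, as the bin-packing rule above states it; `padded_tokens` computes the padding-inclusive size that actually drives memory. Both function names are hypothetical, not taken from a specific training framework.

```python
def first_fit_decreasing(lengths, max_tokens_per_mb):
    """Pack sequence lengths into microbatches whose real-token total stays
    within max_tokens_per_mb, using First-Fit Decreasing."""
    bins, loads = [], []
    for length in sorted(lengths, reverse=True):      # longest first
        if length > max_tokens_per_mb:
            raise ValueError(f"sequence of {length} tokens exceeds the cap")
        for i, load in enumerate(loads):
            if load + length <= max_tokens_per_mb:    # first bin with room
                bins[i].append(length)
                loads[i] += length
                break
        else:                                         # no bin fits: open one
            bins.append([length])
            loads.append(length)
    return bins


def padded_tokens(microbatch):
    """Tokens materialized once every sequence is padded to the longest."""
    return len(microbatch) * max(microbatch)


print(padded_tokens([3000, 5000, 8000]))   # 24000
for cap in (16384, 8192):
    mbs = first_fit_decreasing([3000, 5000, 8000], cap)
    print(cap, mbs, sum(padded_tokens(mb) for mb in mbs))
# 16384 [[8000, 5000, 3000]] 24000
# 8192 [[8000], [5000, 3000]] 18000
```

Because FFD places the longest sequences first, sequences of similar length tend to share a microbatch, which also limits padding. A padding-aware variant could cap `len(mb) * max(mb)` instead of the real-token sum.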

Optimizer State Compression

What it does. Replace the fp32 optimizer states with lower-precision or more compact versions. 8-bit Adam (e.g., bitsandbytes) quantizes the first and second moments to 8 bits, reducing optimizer memory from \(12P\) to \(6P\) bytes. Adafactor does not store a full second moment: for each weight matrix it keeps only row and column statistics from which the second moment is approximated, and it is commonly run without a first moment, which brings optimizer memory down to roughly the \(4P\) fp32 master copy.

The tradeoff. 8-bit Adam introduces quantization noise into the moment estimates. In practice, dynamic quantization with block-wise scaling preserves most of the convergence properties of full-precision Adam. Adafactor can diverge on some tasks and requires more careful hyperparameter tuning. SGD with momentum costs \(8P\) bytes but converges more slowly and with less stability on large-scale language model training.

Parameter-Efficient Fine-Tuning (LoRA / QLoRA)

What it does. Freeze the base model weights and train only small low-rank adapter matrices. LoRA adds rank-\(r\) matrices \(A \in \mathbb{R}^{d \times r}\) and \(B \in \mathbb{R}^{r \times d}\) to each target layer, with \(r \ll d\). Trainable parameters drop to \(\sim 1\text{-}2\%\) of the original. QLoRA goes further: quantize the frozen base model to 4-bit (0.5 bytes/param), reducing parameter memory from \(2P\) to \(0.5P\).

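The tradeoff. Optimizer states and gradients now scale with the adapter size rather than the full model, a massive memory saving. Activation memory, however, still scales with the base model’s hidden size, depth, batch size, and sequence length, because backpropagation must pass through the frozen layers to reach every adapter. The cost is reduced expressiveness: the adapter can only learn changes within the low-rank subspace. For many fine-tuning tasks this is sufficient, but for pretraining or large distribution shifts, full-rank updates are necessary.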

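\[\text{QLoRA memory} \approx \underbrace{0.5P}_{\text{4-bit base}} + \underbrace{2 \cdot 0.02P}_{\text{bf16 adapter}} + \underbrace{12 \cdot 0.02P}_{\text{adapter optimizer}} + \underbrace{2 \cdot 0.02P}_{\text{adapter grads}} \approx 0.82P \text{ bytes}\]

(Assuming the adapters add 2% of \(P\) as trainable parameters; activations excluded.)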
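Counting adapter parameters shows where the small trainable fraction comes from. This sketch assumes square \(d \times d\) target weights, as in the definition above; \(d = 4096\) and the ranks are illustrative choices. The overall share of \(P\) also depends on how many of the model's matrices receive adapters.

```python
def lora_trainable_fraction(d, r):
    """Trainable share for one adapted d x d weight: A is d x r, B is r x d."""
    return (d * r + r * d) / (d * d)


for r in (8, 16, 64):
    print(r, lora_trainable_fraction(4096, r))
# 8 0.00390625
# 16 0.0078125
# 64 0.03125
```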

Offloading

What it does. Move optimizer states or parameters to CPU RAM (or even NVMe storage) and swap them to GPU only when needed. DeepSpeed ZeRO-Offload and ZeRO-Infinity implement this transparently.

The tradeoff. PCIe bandwidth between CPU and GPU is 1-2 orders of magnitude slower than GPU memory bandwidth. Offloading trades training speed for the ability to train models that would otherwise not fit at all. It is a last resort, useful for training very large models on limited GPU hardware.

Summary

Technique Memory saved Cost
Mixed precision (bf16) Halved activation memory (model state unchanged vs fp32) Slight precision loss
Gradient checkpointing ~5-10x activation reduction ~30-40% more compute
8-bit optimizer \(6P\) bytes Quantization noise
QLoRA ~\(15P\) bytes vs full fine-tuning Reduced expressiveness
CPU offloading Moves optimizer to RAM Major speed reduction
Microbatching Controls peak activation memory More kernel launches, lower GPU utilization

None of these are free — and all of the above are single-GPU techniques. When the model still doesn’t fit, or when you need to scale to dozens or hundreds of GPUs, parallelism strategies distribute the memory (and compute) across devices. The next section covers these.

Parallelism Strategies

The single-GPU optimizations above can only go so far. A 70B model requires \(16 \times 70 \times 10^9 = 1120\) GB for parameters + optimizer + gradients alone — no single GPU comes close. Parallelism distributes memory and compute across multiple devices. The four main strategies are complementary and are typically combined in practice.

Recall from What Lives on the GPU that total per-GPU memory is:

\[\underbrace{2P}_{\text{params}} + \underbrace{12P}_{\text{optimizer}} + \underbrace{2P}_{\text{grads}} + \underbrace{\text{Act}(b, s)}_{\text{activations}} = 16P + \text{Act}\]

Each parallelism strategy targets different terms in this equation.

Data Parallelism (DP)

Idea. Replicate the entire model on each of \(N\) GPUs. Each GPU processes a different data shard, computes gradients locally, then all-reduces gradients before the optimizer step. The result is mathematically identical to single-GPU training with \(N \times\) the batch size.

Per-GPU memory:

\[\underbrace{2P}_{\text{params}} + \underbrace{12P}_{\text{optimizer}} + \underbrace{2P}_{\text{grads}} + \text{Act}(b/N, s)\]

DP does not reduce model-related memory — every GPU still holds the full \(16P\) bytes. It only reduces activations by shrinking each GPU’s local batch from \(b\) to \(b/N\).

Communication. One all-reduce of gradients (\(2P\) bytes) per step, which can be overlapped with backward computation using bucketed gradient all-reduce (as in PyTorch DDP).
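One common way to implement the gradient all-reduce is a ring: a reduce-scatter phase followed by an all-gather phase, in which each rank sends about \(2(N-1)/N\) of the gradient buffer in total. The single-process simulation below is a sketch of that algorithm only; production collective libraries such as NCCL choose among several algorithms and overlap transfers with computation. The reduce-scatter half on its own is the collective ZeRO-2 relies on.

```python
def ring_all_reduce(buffers):
    """Simulate ring all-reduce (sum) across len(buffers) ranks.

    buffers[r] is rank r's gradient vector; its length must split evenly
    into one chunk per rank. Returns every rank's final buffer.
    """
    n = len(buffers)
    chunk = len(buffers[0]) // n
    assert chunk * n == len(buffers[0]), "buffer must split into n chunks"
    data = [list(b) for b in buffers]

    def span(c):
        return range(c * chunk, (c + 1) * chunk)

    # Reduce-scatter: at step k, rank r sends chunk (r - k) to rank r + 1,
    # which adds it in. Afterwards rank r holds the full sum of chunk r + 1.
    for k in range(n - 1):
        msgs = [(r, (r - k) % n, [data[r][j] for j in span((r - k) % n)])
                for r in range(n)]                       # snapshot, then deliver
        for r, c, vals in msgs:
            for j, val in zip(span(c), vals):
                data[(r + 1) % n][j] += val

    # All-gather: each rank forwards the finished chunk it most recently
    # received (initially its own), overwriting the neighbour's copy.
    for k in range(n - 1):
        msgs = [(r, (r + 1 - k) % n, [data[r][j] for j in span((r + 1 - k) % n)])
                for r in range(n)]
        for r, c, vals in msgs:
            for j, val in zip(span(c), vals):
                data[(r + 1) % n][j] = val
    return data


grads = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]   # 3 ranks, 3 chunks of 1
print(ring_all_reduce(grads))   # [[111, 222, 333], [111, 222, 333], [111, 222, 333]]
```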

When to use. When the model fits on a single GPU but you want higher throughput. DP is the simplest and most efficient form of parallelism — always the first thing to try.

ZeRO / FSDP (Sharded Data Parallelism)

The key insight behind ZeRO (Zero Redundancy Optimizer) is that vanilla DP is wasteful: every GPU stores an identical copy of optimizer states, gradients, and parameters. ZeRO shards these across \(N\) data-parallel GPUs in three progressive stages:

Stage What is sharded Per-GPU memory Communication per step
ZeRO-1 Optimizer states \(2P + 2P + 12P/N + \text{Act}\) Same as DP (gradient all-reduce)
ZeRO-2 Optimizer states + gradients \(2P + 2P/N + 12P/N + \text{Act}\) Reduce-scatter gradients (similar cost to all-reduce)
ZeRO-3 Optimizer states + gradients + parameters \(2P/N + 2P/N + 12P/N + \text{Act}\) All-gather params before each forward/backward layer

ZeRO-1 is nearly free: sharding optimizer states across \(N\) GPUs reduces the dominant \(12P\) term to \(12P/N\), with no extra communication beyond the standard gradient all-reduce. For a 7B model on 8 GPUs, optimizer memory drops from 84 GB to ~10.5 GB per GPU.

ZeRO-2 additionally shards gradients. Instead of all-reducing the full gradient, the ranks reduce-scatter it, so each rank ends up holding only the summed gradients for the parameter shard whose optimizer state it owns. Communication cost is similar to all-reduce.

ZeRO-3 (equivalent to PyTorch FSDP in its full-sharding mode) shards everything: the full \(16P\) becomes \(16P/N\) per GPU. The cost: parameters must be all-gathered before each layer’s forward and backward pass, and freed immediately after. This turns every layer into a communication event.

\[\text{ZeRO-3 per-GPU model memory} = \frac{16P}{N} \text{ bytes}\]
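The per-stage formulas in the table translate directly into code. This sketch counts model-related memory only (activations excluded) and uses this post's byte counts; "stage 0" is just a label for plain DP.

```python
def zero_bytes_per_gpu(num_params, n_gpus, stage):
    """Model-state bytes per GPU for DP (stage 0) and ZeRO stages 1-3."""
    P, N = num_params, n_gpus
    params = 2 * P / N if stage >= 3 else 2 * P     # bf16 weights
    grads = 2 * P / N if stage >= 2 else 2 * P      # bf16 gradients
    optim = 12 * P / N if stage >= 1 else 12 * P    # fp32 master + m + v
    return params + grads + optim


for stage in range(4):
    print(f"stage {stage}: {zero_bytes_per_gpu(7e9, 8, stage) / 1e9:.2f} GB")
# stage 0: 112.00 GB, stage 1: 38.50 GB, stage 2: 26.25 GB, stage 3: 14.00 GB
```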

Practical notes. EleutherAI reports that ZeRO-3 is “too communication-heavy at large scales” and prefers ZeRO-1 combined with tensor and pipeline parallelism. ZeRO-1 is the default for most training runs because it targets the largest memory consumer (optimizer states) with minimal overhead. ZeRO-3/FSDP shines when GPU count is moderate and interconnect is fast (e.g., 8 GPUs within a single node on NVLink).

Tensor Parallelism (TP)

Idea. Split individual weight matrices across \(t\) GPUs so that each GPU computes a slice of every layer. For a linear layer \(Y = XW\), the weight \(W\) is column- or row-split, each GPU computes its portion, and the results are combined via all-reduce or all-gather.
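Both splits can be checked with nested lists. In this sketch each iteration over `i` stands in for one of the \(t\) GPUs; concatenating the outputs plays the role of the all-gather for a column split, and summing the partial products plays the role of the all-reduce for a row split. The helper names are hypothetical.

```python
def matmul(X, W):
    """Plain nested-list matrix product X @ W."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]


def column_parallel(X, W, t):
    """Split W by columns over t ranks; concatenating the outputs = all-gather."""
    k = len(W[0]) // t
    outs = [matmul(X, [row[i * k:(i + 1) * k] for row in W]) for i in range(t)]
    return [sum((out[r] for out in outs), []) for r in range(len(X))]


def row_parallel(X, W, t):
    """Split W by rows (and X by columns) over t ranks; summing = all-reduce."""
    k = len(W) // t
    outs = [matmul([x[i * k:(i + 1) * k] for x in X], W[i * k:(i + 1) * k])
            for i in range(t)]
    return [[sum(out[r][c] for out in outs) for c in range(len(W[0]))]
            for r in range(len(X))]


X = [[1, 2, 3, 4], [5, 6, 7, 8]]
W = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 1, 1, 0]]
print(column_parallel(X, W, 2) == row_parallel(X, W, 2) == matmul(X, W))  # True
```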

Per-GPU memory:

\[\frac{2P}{t} + \frac{12P}{t} + \frac{2P}{t} + \text{Act}(b, s, t) = \frac{16P}{t} + \text{Act}\]

where activations are also partially split: the \(24/t\) and \(5as/(dt)\) terms in the activation formula reflect this sharding.

Communication. Two all-reduce operations per transformer layer in the forward pass (one in the attention block, one in the MLP), plus two more in the backward pass, each communicating \(O(bsd)\) activation tensors. These sit on the critical path: computation cannot proceed until the all-reduce completes. This is why TP requires NVLink (~900 GB/s) rather than PCIe (~64 GB/s) or network interconnects.

Typical TP degree. TP is usually set to the number of GPUs within a single node (e.g., \(t = 8\) for an 8-GPU node with NVLink). Going beyond a node boundary is impractical because the inter-node bandwidth is too low for the frequent all-reduces.

Pipeline Parallelism (PP)

Idea. Partition the model’s \(n_{\text{layers}}\) layers into \(p\) stages, assigning consecutive layers to different GPUs. GPU 1 runs layers 1 to \(n_{\text{layers}}/p\), GPU 2 runs layers \(n_{\text{layers}}/p + 1\) to \(2n_{\text{layers}}/p\), and so on.

Per-GPU memory:

\[\frac{2P}{p} + \frac{12P}{p} + \frac{2P}{p} + \text{Act} = \frac{16P}{p} + \text{Act}\]

Each GPU only stores parameters, optimizer states, and gradients for its assigned layers. Activation memory depends on the schedule (see below).

Communication. Point-to-point sends of activation tensors (\(O(bsd)\)) between adjacent stages — much less volume than TP’s all-reduces, and tolerant of lower-bandwidth interconnects.

The bubble problem. With naive sequential execution, only one stage is active at a time — the other \(p - 1\) GPUs are idle. The pipeline bubble is the fraction of time wasted:

\[\text{Bubble fraction} = \frac{p - 1}{m + p - 1}\]

where \(m\) is the number of microbatches. To keep the bubble small (say, < 5%), you need \(m \gg p\). GPipe and PipeDream use different strategies:

  • GPipe: Runs all \(m\) microbatch forward passes, then all \(m\) backward passes. Simple, but all activations for all microbatches must be held simultaneously, increasing memory by a factor of \(m\).
  • 1F1B (PipeDream): Interleaves forward and backward passes so that each GPU holds activations for at most \(p\) microbatches (instead of \(m\)). This significantly reduces activation memory at the cost of more complex scheduling.
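Plugging numbers into the bubble formula shows how large \(m\) must be. The sketch uses exact fractions; `peak_microbatches_in_flight` is a hypothetical helper that simply encodes the two bullets above (all \(m\) for GPipe, at most \(p\) for 1F1B) rather than simulating a schedule.

```python
from fractions import Fraction


def bubble_fraction(p, m):
    """Idle fraction of a p-stage pipeline running m microbatches."""
    return Fraction(p - 1, m + p - 1)


def min_microbatches(p, target):
    """Smallest m whose bubble fraction is at most `target`."""
    m = 1
    while bubble_fraction(p, m) > target:
        m += 1
    return m


def peak_microbatches_in_flight(p, m, schedule):
    """Microbatches whose activations one stage may hold at once."""
    return m if schedule == "gpipe" else min(p, m)   # otherwise "1f1b"


print(bubble_fraction(4, 8))                  # 3/11
print(min_microbatches(4, Fraction(1, 20)))   # 57
print(peak_microbatches_in_flight(4, 57, "gpipe"),
      peak_microbatches_in_flight(4, 57, "1f1b"))   # 57 4
```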

Combining Strategies: 3D Parallelism

In practice, large-scale training combines all three: TP within a node, PP across nodes, and DP (with ZeRO-1) across pipeline-parallel replicas. For \(N\) total GPUs with TP degree \(t\), PP degree \(p\), and DP degree \(d = N / (t \cdot p)\) (in this subsection \(d\) is the DP degree, not the hidden size):

\[\text{Per-GPU model memory} \approx \frac{16P}{t \cdot p} - \frac{12P}{t \cdot p}\left(1 - \frac{1}{d}\right) = \frac{4P}{t \cdot p} + \frac{12P}{t \cdot p \cdot d}\]

More concretely with ZeRO-1:

\[\text{Per-GPU memory} = \frac{2P}{t \cdot p} + \frac{12P}{t \cdot p \cdot d} + \frac{2P}{t \cdot p} + \text{Act}(b_{\text{local}}, s, t)\]

Example: 70B model on 64 GPUs (8 nodes × 8 GPUs/node):

  • \(t = 8\) (TP within each node), \(p = 4\) (PP across 4 nodes), \(d = 2\) (2 DP replicas)
  • Parameters per GPU: \(2 \times 70\text{B} / (8 \times 4) = 4.4\) GB
  • Optimizer per GPU: \(12 \times 70\text{B} / (8 \times 4 \times 2) = 13.1\) GB
  • Gradients per GPU: \(2 \times 70\text{B} / (8 \times 4) = 4.4\) GB
  • Model-related total: ~21.9 GB per GPU — comfortably fits on an 80 GB A100
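The worked example can be reproduced with a short function. It follows the ZeRO-1 formula above, names the DP degree `dp` to avoid the clash with the hidden size, and, like the bullets, leaves out activations.

```python
def per_gpu_model_bytes(num_params, tp, pp, dp):
    """Model-state bytes per GPU for TP x PP x DP with ZeRO-1 across DP."""
    shard = num_params / (tp * pp)        # parameters owned by one TP/PP rank
    return {
        "params": 2 * shard,              # bf16 weights
        "optimizer": 12 * shard / dp,     # ZeRO-1 also splits this across DP
        "grads": 2 * shard,               # bf16 gradients
    }


mem = per_gpu_model_bytes(70e9, tp=8, pp=4, dp=2)
print({k: round(v / 1e9, 1) for k, v in mem.items()}, round(sum(mem.values()) / 1e9, 1))
# {'params': 4.4, 'optimizer': 13.1, 'grads': 4.4} 21.9
```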
Strategy Splits Communication Interconnect Reduces
DP Data All-reduce gradients Any Activation memory (via smaller local batch)
ZeRO-1 Optimizer states Same as DP Any Optimizer memory
ZeRO-3 / FSDP Everything All-gather per layer NVLink preferred All model memory
TP Weight matrices All-reduce per layer NVLink required All model memory + activations
PP Layers Point-to-point activations Network OK All model memory

Checkpointing to Disk

GPU memory is volatile — if a node crashes, a job gets preempted, or you simply want to pause and resume later, everything on the GPU is lost. Checkpointing saves the training state to persistent storage so that a run can be restarted from where it left off without retraining from scratch.

A complete checkpoint must contain everything needed to resume training exactly where it left off (bit-identically, provided the kernels themselves are deterministic). This means more than just the model weights:

Component What it contains Size (7B model, AdamW, bf16) Why it’s needed
Model parameters (fp32 master) The fp32 master copy of all weights \(4P = 28\) GB The authoritative weights; bf16 copies are derived from these
Optimizer states First moment \(m_t\), second moment \(v_t\), step count \(8P = 56\) GB Without these, the optimizer restarts with zero momentum/variance — causing a spike in loss and effectively wasting the warmup
Learning rate scheduler state Current step, warmup progress, decay schedule Negligible Ensures the learning rate continues from the correct position
RNG states Generator states for the CPU and every GPU, which drive dropout masks and data shuffling Negligible Required for exact reproducibility; without these, the resumed run diverges from the original
Data loader state Current epoch, sample index, shuffle order Small Prevents re-training on already-seen data or skipping unseen data
Gradient scaler state (fp16 only) Current loss scale, backoff count Negligible fp16 training uses dynamic loss scaling; resetting it causes unnecessary scale search

What You Can Skip

Gradients do not need to be saved. They are recomputed from scratch at the start of each training step. Saving mid-step (between microbatches during gradient accumulation) would require saving the partially accumulated gradients, but this is rarely done — it’s simpler to checkpoint only at step boundaries.

Activations do not need to be saved. They are transient, created during the forward pass and consumed during the backward pass within a single step.

The bf16 model copy does not need to be saved. It is deterministically derived from the fp32 master weights by casting.

Checkpoint Size

The dominant cost is model parameters + optimizer states. For AdamW in mixed precision:

\[\text{Checkpoint size} \approx 4P + 8P = 12P \text{ bytes}\]

For a 7B model, that’s ~84 GB per checkpoint. A 70B model produces ~840 GB checkpoints. At typical save frequencies (every few hundred steps), this accumulates quickly — a 70B training run saving every 500 steps for 100K steps produces ~168 TB of checkpoints if none are pruned.

Strategies to Reduce Checkpoint Cost

Async checkpointing. Saving 84 GB to networked storage (e.g., NFS, S3) can take minutes. Synchronous saves stall training. Modern frameworks (PyTorch’s torch.distributed.checkpoint, DeepSpeed) copy the state to CPU memory asynchronously and write to disk in a background thread, overlapping I/O with the next training step.
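The pattern (copy first, then write in the background) can be sketched with the standard library. This is not the API of `torch.distributed.checkpoint` or DeepSpeed; `async_save` is a hypothetical helper, and `copy.deepcopy` stands in for the GPU-to-CPU copy those frameworks perform.

```python
import copy
import os
import pickle
import tempfile
import threading


def async_save(state, path):
    """Snapshot `state`, then write it to `path` on a background thread.

    The deepcopy stands in for the device-to-host copy: once it returns,
    training may keep mutating `state` without corrupting the checkpoint.
    """
    snapshot = copy.deepcopy(state)

    def _write():
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(snapshot, f)
        os.replace(tmp_path, path)   # rename so readers never see a partial file

    writer = threading.Thread(target=_write)
    writer.start()
    return writer                     # join() before exiting or saving again


if __name__ == "__main__":
    state = {"step": 500, "weights": [0.1, 0.2, 0.3]}
    path = os.path.join(tempfile.mkdtemp(), "ckpt_500.pkl")
    writer = async_save(state, path)
    state["step"] += 1                 # the "next training step" proceeds
    writer.join()
    with open(path, "rb") as f:
        print(pickle.load(f)["step"])  # 500
```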

Sharded checkpointing. With data or model parallelism, each GPU saves only its own shard of the state. This parallelizes the I/O across all nodes and avoids gathering the full state onto a single machine. The downside is that loading requires the same parallelism configuration — resharding is needed if you change the number of GPUs.

Save only what changed. Some systems support incremental or delta checkpoints, saving only the difference from the previous checkpoint. This is most useful when checkpoints are frequent and the model changes slowly between saves.

Pruning old checkpoints. Keep the last \(k\) checkpoints and delete older ones. Optionally keep “milestone” checkpoints at longer intervals (e.g., every 10K steps) for evaluation or fallback.
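A retention policy combining both ideas takes only a few lines. The function and parameter names are hypothetical and the step numbers illustrative; the function only decides which steps to delete, leaving the actual deletion to whatever storage layer holds the checkpoints.

```python
def checkpoints_to_delete(saved_steps, keep_last, milestone_every):
    """Steps whose checkpoints can be pruned under a keep-last-k policy
    that also preserves every `milestone_every`-th step."""
    steps = sorted(saved_steps)
    recent = set(steps[-keep_last:]) if keep_last > 0 else set()
    return [s for s in steps if s not in recent and s % milestone_every != 0]


saved = range(500, 25_001, 500)          # a checkpoint every 500 steps
doomed = checkpoints_to_delete(saved, keep_last=3, milestone_every=10_000)
print(len(saved) - len(doomed), "kept")  # 5 kept: 10000, 20000, 24000, 24500, 25000
```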