Jaisidh Singh

The Ultra-Scale Playbook vol-3: DeepSpeed ZeRO

Zero Redundancy Optimiser

TLDR: Instead of naive data parallelism, where every rank keeps a full copy of the params, grads, and optimiser states, ZeRO shards this state across the data-parallel ranks so each rank only holds 1/Nd of it.

3 stages of ZeRO

  1. ZeRO-1: only partition optimiser states
  2. ZeRO-2: partition optimiser states + gradients
  3. ZeRO-3: partition optimiser states + gradients + params

Given the model parameter count N, mixed precision training with Adam dictates the following memory usage: 2N bytes for the bf16 params, 2N bytes for the bf16 grads, and 12N bytes for the optimiser states (an fp32 master copy of the params plus fp32 momentum and variance, 4N bytes each).

For efficiency, let's keep grad-accumulation in bf16, and so the total memory usage becomes 2N+2N+12N = 16N bytes per rank. Now, given data parallel degree Nd, each ZeRO stage shards a progressively larger slice of this across the Nd ranks.
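To make this accounting concrete, here is a minimal sketch in plain Python (not DeepSpeed code); the 7B-parameter model is just an assumed example:

```python
# Naive mixed-precision + Adam memory per data-parallel replica,
# matching the 2N + 2N + 12N breakdown above.
def mixed_precision_memory_gb(num_params: float) -> dict:
    GB = 1e9
    bf16_params  = 2 * num_params            # bf16 parameters, 2 bytes each
    bf16_grads   = 2 * num_params            # grads kept in bf16, 2 bytes each
    optim_states = (4 + 4 + 4) * num_params  # fp32 master params + momentum + variance
    total = bf16_params + bf16_grads + optim_states
    return {
        "params_bf16": bf16_params / GB,
        "grads_bf16": bf16_grads / GB,
        "optimiser_fp32": optim_states / GB,
        "total": total / GB,
    }

# e.g. a hypothetical 7B-param model: 16 bytes/param ≈ 112 GB per replica,
# before activations even enter the picture.
print(mixed_precision_memory_gb(7e9))
```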

ZeRO-1: only shard optimiser states

Shard only the 12N of optimiser states across the Nd ranks, so per-rank memory becomes 2N+2N+12N/Nd.

💡 To update its chunk of the optimiser states, each rank only needs that chunk of the gradients reduced across the different microbatches, so a reduce_scatter is enough; no rank needs the reduced gradients of the other chunks.
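Here is a minimal sketch of what a ZeRO-1 update could look like with raw torch.distributed collectives; `optimizer_shard_step` is a hypothetical stand-in for the Adam update on this rank's fp32 shard, and the flat tensors are assumed to be padded so they divide evenly across ranks:

```python
import torch
import torch.distributed as dist

def zero1_step(flat_params, flat_grads, optimizer_shard_step):
    """flat_params / flat_grads: 1-D bf16 tensors of length N (padded so N % Nd == 0).
    optimizer_shard_step(grad_shard) -> updated bf16 param shard (assumed helper)."""
    world_size = dist.get_world_size()
    shard_numel = flat_grads.numel() // world_size

    # 1. reduce_scatter: each rank receives only the fully-reduced gradient
    #    chunk that its optimiser-state shard needs.
    grad_shard = torch.empty(shard_numel, dtype=flat_grads.dtype, device=flat_grads.device)
    dist.reduce_scatter_tensor(grad_shard, flat_grads, op=dist.ReduceOp.SUM)
    grad_shard /= world_size  # average over the Nd ranks' microbatches

    # 2. Local optimiser step on this rank's 1/Nd shard of the fp32 states.
    param_shard = optimizer_shard_step(grad_shard)

    # 3. all_gather the updated bf16 parameter shards so every rank
    #    again holds the full parameters for the next forward pass.
    dist.all_gather_into_tensor(flat_params, param_shard)
```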

What motivates ZeRO-2: why not reduce each gradient chunk across all the data during back-prop, and then keep only the chunk required for our optimiser step? That eliminates the need to ever store all the grads.

ZeRO-2: shard optimiser states + grad

Perform the reduce_scatter during back-prop itself, and keep only our own gradient chunk as it gets reduced. Now only 1/Nd-th of the gradients are needed in memory, freeing up memory and giving us a much better memory footprint of 2N+(2N+12N)/Nd.
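For intuition, the per-rank footprints can be compared numerically; this is a sketch using the formulas from this post (the ZeRO-3 row anticipates the full (2N+2N+12N)/Nd sharding described next), and the 7B / Nd=64 numbers are just an assumed example:

```python
def zero_footprint_gb(N, Nd):
    """Per-rank memory in GB (1e9 bytes) for each ZeRO stage."""
    return {
        "no ZeRO": (2*N + 2*N + 12*N) / 1e9,
        "ZeRO-1":  (2*N + 2*N + 12*N / Nd) / 1e9,
        "ZeRO-2":  (2*N + (2*N + 12*N) / Nd) / 1e9,
        "ZeRO-3":  ((2*N + 2*N + 12*N) / Nd) / 1e9,
    }

# e.g. N = 7e9 params, Nd = 64 ranks:
# no ZeRO ≈ 112 GB, ZeRO-1 ≈ 29.3 GB, ZeRO-2 ≈ 15.5 GB, ZeRO-3 ≈ 1.75 GB
print(zero_footprint_gb(7e9, 64))
```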

What motivates ZeRO-3: the params themselves can also be distributed across DP ranks, and the forward pass stays possible if each rank does an all_gather per layer for each microbatch. Think of it this way: we temporarily "stitch" all the shards of a layer together, forward-pass a microbatch through it, then flush the gathered shards so that only 1/Nd of the params stays in memory.

ZeRO-3: shard everything

💡 ZeRO-3 requires 2·num_layers − 1 additional calls to all_gather compared to ZeRO-2, and each comes with a small base latency. Also, we gather all the parameter shards once in the forward pass and once during back-prop, incurring a communication tax of N each time. Adding another communication tax of N from the reduce_scatter called during back-prop, our total communication cost is 3N, compared to 2N in ZeRO-2.
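A minimal sketch of the gather-compute-flush pattern in the forward pass (not the DeepSpeed implementation; `layer_fn` and the flattened per-layer shards are assumptions for illustration):

```python
import torch
import torch.distributed as dist

def zero3_forward(x, layers):
    """layers: list of (shard, layer_fn) pairs, where `shard` is this rank's
    1/Nd slice of a layer's flattened bf16 weights and layer_fn(x, full_weights)
    applies the layer."""
    world_size = dist.get_world_size()
    for shard, layer_fn in layers:
        # "stitch" the shards: materialise the layer's full weights on every rank
        full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard)
        x = layer_fn(x, full)
        del full  # flush the gathered copy; only the 1/Nd shard stays resident
    return x
```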

💡 While this may seem like a lot of communication overhead, in practice we use prefetching to hide it: simply all_gather the weights of the next layer while we forward-pass through the current layer, and similarly all_gather the weights of the previous layer while back-propping through the current one.
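Extending the sketch above, the overlap could look roughly like this: launch the all_gather for layer i+1 asynchronously while layer i is still computing, and only wait on it right before it's needed (again a sketch, not DeepSpeed's actual prefetcher):

```python
import torch
import torch.distributed as dist

def zero3_forward_prefetch(x, layers):
    world_size = dist.get_world_size()

    def start_gather(shard):
        # kick off a non-blocking all_gather and return the buffer + work handle
        full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        handle = dist.all_gather_into_tensor(full, shard, async_op=True)
        return full, handle

    full, handle = start_gather(layers[0][0])           # gather layer 0 up-front
    for i, (_shard, layer_fn) in enumerate(layers):
        if i + 1 < len(layers):                         # prefetch layer i+1's weights
            next_full, next_handle = start_gather(layers[i + 1][0])
        handle.wait()                                   # make sure layer i's weights arrived
        x = layer_fn(x, full)
        del full                                        # flush layer i's gathered weights
        if i + 1 < len(layers):
            full, handle = next_full, next_handle
    return x
```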

When will this fail: once our DP dimension grows beyond roughly 512 ranks, ring latency makes the communication overhead too large to hide and the overlap with compute breaks down. We need to think of something else at those scales.