einer Serie über verteilte KI über mehrere GPUs:
Einführung
Im vorherigen Beitrag haben wir gesehen, wie Distributed Knowledge Parallelism (DDP) das Coaching beschleunigt, indem es Stapel auf GPUs aufteilt. DDP löst das Durchsatzproblem, bringt jedoch eine neue Herausforderung mit sich: Speicherredundanz.
In Vanilla DDP enthält jede GPU eine vollständige Kopie der Modellparameter, Verläufe und Optimiererzustände. Bei großen Modellen wie GPT-3 (175B-Parameter) führt diese Redundanz zu einer großen Verschwendung wertvollen VRAMs.

ZeRO (Zero Redundancy Optimizer) löst dieses Drawback. Es gibt drei Ebenen:
- ZeRO-1 Partitionen nur Optimiererzustände
- ZeRO-2 Partitioniert Optimiererzustände + Farbverläufe
- ZeRO-3 Partitioniert Optimiererzustände + Verläufe + Modellparameter
ZeRO ist keine Parallelitätstechnik, da alle GPUs immer noch die gleichen Vorwärts- und Rückwärtsdurchläufe ausführen. Es ist ein Speicheroptimierung Strategie, die Redundanz zwischen GPUs eliminiert und es Ihnen ermöglicht, größere Modelle auf derselben {Hardware} zu trainieren.
Das Speicherproblem in DDP
Lassen Sie uns aufschlüsseln, was während des Trainings tatsächlich Speicher verbraucht. Für ein Modell mit Parametern:
- Modellparameter: Werte (die Gewichte Ihres neuronalen Netzwerks)
- Farbverläufe: Werte (ein Gradient professional Parameter)
- Optimiererzustände (Adam): Werte (erstes Second und zweites Second für jeden Parameter)
- Aktivierungen: Zwischenausgaben, die während des Vorwärtsdurchlaufs zur Verwendung im Rückwärtsdurchlauf gespeichert werden
Die ersten drei skalieren mit Modellgröße und sind GPU-übergreifend redundant im DDP. Die Aktivierungen skalieren mit der Batch-Größe, der Sequenzlänge und der Anzahl der Neuronen Einzigartig professional GPU da jede GPU unterschiedliche Daten verarbeitet. ZeRO berührt den Aktivierungsspeicher nicht.
Berechnen wir die Speichernutzung für ein 7B-Parameter-Modell mit Adam und FP32:
- Parameter: 7 Milliarden * 4 Bytes = 28 GB
- Farbverläufe: 7 Milliarden * 4 Bytes = 28 GB
- Der Optimierer gibt an: 7 Milliarden * 2 * 4 Bytes = 56 GB
- Speicher professional GPU in DDP: 112 GB
Aktivierungen fügen darüber hinaus erheblichen Speicher hinzu, aber da sie professional GPU einzigartig sind, kann ZeRO sie nicht partitionieren. Techniken wie Aktivierungs-Checkpointing kann helfen, es verwirft einige Aktivierungen und berechnet sie dann nach Bedarf während des Rückwärtsdurchlaufs neu. Aber das würde den Rahmen dieses Artikels sprengen.
Lassen Sie uns verstehen, wie ZeRO funktioniert, indem wir es von Grund auf implementieren, beginnend mit ZeRO-1 und uns bis hin zu ZeRO-3 vorarbeiten.
ZeRO-1: Optimierer-Zustandspartitionierung
In ZeRO-1 nur die Optimiererzustände sind partitioniert. Jede GPU:
- Hält immer noch vollständige Modellparameter und -verläufe
- Nur Geschäfte 1/N der Optimiererzustände (N = Anzahl der GPUs)
- Aktualisiert nur das entsprechende 1/N der Parameter
Dies ist die Reihenfolge der während des Trainings durchgeführten Aktionen:
- Vorwärtspass: Jede GPU verarbeitet ihren eigenen Mikrobatch
- Rückwärtspass: Berechnen Sie Farbverläufe
all-reduceSteigungen: Jede GPU erhält alle Farbverläufe- Optimierungsschritt: Jede GPU aktualisiert ihre Parameterpartition
all-gatherParameter: Synchronisieren Sie das aktualisierte Modell über GPUs hinweg

Hier ist eine vereinfachte Implementierung:
import torch
import torch.distributed as dist
class ZeRO_1:
def __init__(self, mannequin, optimizer_cls):
self.mannequin = mannequin
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_shards = listing() # every rank holds solely its shard of the optimizer states
self.param_metadata = listing() # metadata to reconstruct shards
for param in self.mannequin.parameters():
original_shape = param.information.form
flat = param.information.view(-1)
numel = flat.numel()
the rest = numel % self.world_size
pad_size = (self.world_size - the rest) % self.world_size
padded_numel = numel + pad_size
shard_size = padded_numel // self.world_size
shard_start = self.rank * shard_size
shard_end = shard_start + shard_size
self.param_metadata.append(
{
"original_shape": original_shape,
"numel": numel,
"padded_numel": padded_numel,
"shard_size": shard_size,
"shard_start": shard_start,
"shard_end": shard_end,
}
)
if pad_size > 0:
flat_padded = torch.cat((flat, flat.new_zeros(pad_size)))
else:
flat_padded = flat
shard = flat_padded(shard_start:shard_end).clone()
shard.requires_grad_(True)
self.param_shards.append(shard)
self.optimizer = optimizer_cls(self.param_shards)
def training_step(self, inputs, targets, loss_fn):
output = self.mannequin(inputs) # ahead
loss = loss_fn(output, targets) # compute loss
loss.backward() # backward
self._sync_gradients() # all-reduce gradients throughout GPUs
self.optimizer.step() # replace native shard of parameters
self._sync_params() # all collect mannequin params
# clear gradients for the subsequent step
for param in self.mannequin.parameters():
param.grad = None
def _sync_gradients(self):
for idx, param in enumerate(self.mannequin.parameters()):
meta = self.param_metadata(idx)
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= self.world_size
self.param_shards(idx).grad = param.grad.view(-1)(meta("shard_start"):meta("shard_end"))
def _sync_params(self):
for idx, param in enumerate(self.mannequin.parameters()):
meta = self.param_metadata(idx)
full_flat = torch.empty(meta("padded_numel"), gadget=param.gadget, dtype=param.dtype)
dist.all_gather_into_tensor(
output_tensor=full_flat,
input_tensor=self.param_shards(idx).information,
)
reconstructed = full_flat(:meta("numel")).view(meta("original_shape"))
param.information.copy_(reconstructed)
Beachten Sie, dass die all-reduce synchronisiert alle Farbverläufe, aber jede GPU verwendet die Farbverläufe nur für ihre eigene Parameterpartition, es kommt zu einer Überkommunikation. ZeRO-2 behebt dieses Drawback, indem es auch die Farbverläufe aufteilt.
In der Praxis würden Sie ZeRO-1 niemals verwenden, da Sie mit ZeRO-2 bei im Wesentlichen gleichen Kosten bessere Speichereinsparungen erzielen. Aber es lohnt sich trotzdem, es zu Lernzwecken noch einmal durchzugehen.
Speicher mit ZeRO-1, 7B-Modell, 8 GPUs:
- Parameter: 28 GB (vollständig repliziert)
- Farbverläufe: 28 GB (vollständig repliziert)
- Optimierer gibt an: 56 GB / 8 = 7 GB
- Insgesamt professional GPU: 63 GB (von GB)
ZeRO-2: Gradientenpartitionierung
ZeRO-2 partitioniert beide Optimiererzustände und Farbverläufe. Da jede GPU nur eine Partition von Parametern aktualisiert, benötigt sie nur die entsprechenden Farbverläufe.
ZeRO-1 verwendet all-reducewas jeder GPU alle Farbverläufe verleiht. ZeRO-2 ersetzt dies durch reduce-scattererhält jede GPU nur die Farbverläufe, die sie tatsächlich benötigt. Dies spart sowohl Speicher als auch Kommunikationsbandbreite.
Trainingsschritte:
- Vorwärtspass: Jede GPU verarbeitet ihren eigenen Mikrobatch
- Rückwärtspass: Berechnen Sie Farbverläufe
reduce-scatterSteigungen: Jede GPU erhält nur ihre Partition- Optimierungsschritt: Jede GPU aktualisiert ihre Parameterpartition
all-gatherParameter: Synchronisieren Sie das aktualisierte Modell über GPUs hinweg

Die Implementierung ist ZeRO-1 sehr ähnlich, verwendet jedoch den Gradientensynchronisationsschritt reduce-scatter anstatt all-reduce:
Aber Second, wenn jede GPU während des Backprops alle Farbverläufe berechnet, wie spart das dann tatsächlich VRAM? So geht’s:
- Da die Parameterverläufe Schicht für Schicht berechnet werden, sind sie sofort sichtbar
reduce-scatteredund die lokale Kopie wird freigegeben (unsere vereinfachte Implementierung führt dies nicht durch). - Beim Backprop benötigen Sie nur den Gradienten der nächsten Neuronenaktivierung, um den Gradienten des aktuellen Parameters zu berechnen, dh Sie benötigen nicht das gesamte Gradientendiagramm.
- Auf diese Weise können Sie den Speicher für Farbverläufe freigeben, während Sie sich rückwärts bewegen, wobei nur die zugewiesene Partition für jede GPU erhalten bleibt.
Speicher mit ZeRO-2, 7B-Modell, 8 GPUs:
- Parameter: 28 GB (vollständig repliziert)
- Farbverläufe: 28 GB / 8 = 3,5 GB
- Optimierer gibt an: 56 GB / 8 = 7 GB
- Gesamt professional GPU: 38,5 GB (vorher 112 GB)
ZeRO-3: Parameterpartitionierung
ZeRO-3 partitioniert Optimiererzustände, -verläufe usw Parameter. Jede GPU speichert nur 1/N des gesamten Modellstatus.
Bei Vorwärts- und Rückwärtsdurchläufen benötigt jede Schicht ihre vollständigen Parameter, aber jede GPU speichert nur einen Bruchteil. Additionally wir Alle Parameter Simply-in-Time erfassenverwenden Sie sie und entsorgen Sie sie sofort danach.
Trainingsschritte:
- Vorwärtspass:
- Sammeln Sie alle Parameter der Ebene von allen GPUs
- Führen Sie den Vorwärtsdurchlauf der Ebene aus und verwenden Sie dabei die Aktivierungen der vorherigen Ebene als Eingabe
- Verwerfen Sie die gesammelten Parameter (behalten Sie nur die lokale Partition bei)
- Wiederholen Sie diese Schritte, bis alle Schichten fertig sind
- Rückwärtsdurchlauf (professional Schicht, rückwärts):
- Erfassen Sie erneut alle Parameter der Ebene
- Berechnen Sie die Farbverläufe für die aktuelle Ebene mithilfe der Aktivierungsgradienten der nächsten Ebene
- Reduzieren Sie die Streuung der Farbverläufe (jede GPU behält ihren Shard)
- Verwerfen Sie die gesammelten Parameter (behalten Sie nur die lokale Partition bei)
- Wiederholen Sie diese Schritte, bis alle Schichten fertig sind
- Jede GPU führt einen Optimierungsschritt auf ihrer Partition aus
- Es ist keine abschließende Gesamterfassung erforderlich, da die Parameter während des Vorwärtsdurchlaufs Schicht für Schicht erfasst werden

Hier ist eine vereinfachte Implementierung:
class ZeRO_3(ZeRO_2):
"""
ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + mannequin parameters (stage 3).
At relaxation, every rank holds solely param_shards(idx) — a 1/world_size slice
of every parameter. Full parameters are materialised briefly throughout
the ahead and backward passes by way of all_gather, then instantly freed.
"""
def __init__(self, mannequin, optimizer_cls):
self.mannequin = mannequin
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_metadata = ()
shard_list = ()
self._param_to_idx = {}
for idx, param in enumerate(self.mannequin.parameters()):
original_shape = param.information.form
flat = param.information.view(-1)
numel = flat.numel()
the rest = numel % self.world_size
pad_size = (self.world_size - the rest) % self.world_size
padded_numel = numel + pad_size
shard_size = padded_numel // self.world_size
shard_start = self.rank * shard_size
shard_end = shard_start + shard_size
self.param_metadata.append(
{
"original_shape": original_shape,
"numel": numel,
"padded_numel": padded_numel,
"shard_size": shard_size,
"shard_start": shard_start,
"shard_end": shard_end,
}
)
if pad_size > 0:
flat_padded = torch.cat((flat, flat.new_zeros(pad_size)))
else:
flat_padded = flat
shard = flat_padded(shard_start:shard_end).clone()
shard_list.append(shard)
# Substitute the total tensor with solely this rank's shard.
# The mannequin's param.information now factors to a tiny slice; the total
# weight will likely be reconstructed on demand throughout ahead/backward.
param.information = shard.detach()
self._param_to_idx(param) = idx
self.param_shards = (s.requires_grad_(True) for s in shard_list)
self.optimizer = optimizer_cls(self.param_shards)
self._register_hooks()
def _gather_param(self, idx, gadget, dtype):
"""All-gather the total parameter tensor for parameter `idx`."""
meta = self.param_metadata(idx)
full_flat = torch.empty(meta("padded_numel"), gadget=gadget, dtype=dtype)
dist.all_gather_into_tensor(
output_tensor=full_flat,
input_tensor=self.param_shards(idx).information,
)
return full_flat(: meta("numel")).view(meta("original_shape"))
def _gather_module_params(self, module):
"""Collect full params for each parameter that belongs to this module solely (not kids)."""
for param in module.parameters(recurse=False):
idx = self._param_to_idx(param)
param.information = self._gather_param(idx, param.gadget, param.dtype)
def _reshard_module_params(self, module):
"""Reshard params again to native shard for each direct param of this module."""
for param in module.parameters(recurse=False):
idx = self._param_to_idx(param)
param.information = self.param_shards(idx).information
def _register_hooks(self):
self._hooks = ()
for module in self.mannequin.modules():
# Skip container modules that haven't any direct parameters
if not listing(module.parameters(recurse=False)):
proceed
# Ahead: collect -> run -> reshard
h1 = module.register_forward_pre_hook(
lambda mod, _inputs: self._gather_module_params(mod)
)
h2 = module.register_forward_hook(
lambda mod, _inputs, _output: self._reshard_module_params(mod)
)
# Backward: collect earlier than grad computation → reshard after
h3 = module.register_full_backward_pre_hook(
lambda mod, _grad_output: self._gather_module_params(mod)
)
h4 = module.register_full_backward_hook(
lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
)
self._hooks.lengthen((h1, h2, h3, h4))
def training_step(self, inputs, targets, loss_fn):
# Hooks deal with all collect/reshard round every module robotically
output = self.mannequin(inputs)
loss = loss_fn(output, targets)
loss.backward()
self._sync_gradients()
# Every rank updates solely its native shard
self.optimizer.step()
for param in self.mannequin.parameters():
param.grad = None
Die Parameter jeder Ebene werden unmittelbar vor ihrem Bedarf erfasst und unmittelbar danach freigegeben. Dadurch wird der Spitzenspeicher auf ein Minimal reduziert, allerdings auf Kosten einer höheren Kommunikation. In der Praxis überlappen Implementierungen das All-Collect für Layer N+1 mit dem Ahead von Layer N, um die Latenz zu verbergen.
Speicher mit ZeRO-3, 7B-Modell, 8 GPUs:
- Parameter: 28 GB / 8 = 3,5 GB
- Farbverläufe: 28 GB / 8 = 3,5 GB
- Optimierer gibt an: 56 GB / 8 = 7 GB
- Insgesamt professional GPU: 14 GB (vorher 112 GB)
Das ist ein 8-fache Reduzierung bei der Speichernutzung, was genau das ist, was wir von einer Partitionierung auf 8 GPUs erwarten würden.
Verwendung von ZeRO in PyTorch
PyTorch wird mit zwei Implementierungen von ZeRO-3 ausgeliefert: FSDP1 (älter, weniger optimiert) und FSDP2 (neuer, empfohlen). Verwenden Sie immer FSDP2.
FSDP (Absolutely Sharded Knowledge Parallel) übernimmt automatisch die Parametererfassung, Gradientenstreuung, Kommunikationsüberlappung und Speicherverwaltung:
from torch.distributed.fsdp import fully_shard
mannequin = Transformer()
for layer in mannequin.layers:
fully_shard(layer)
fully_shard(mannequin)
Sie müssen sich bewerben fully_shard Schicht für Schicht und wickeln Sie dann das gesamte Modell ein.
Abschluss
ZeRO tauscht Erinnerung gegen Kommunikation aus, es ist additionally kein kostenloses Mittagessen. Im Allgemeinen lohnt es sich für kleinere Modelle (z. B. BERT) nicht, aber für größere Modelle ist es ein Recreation Changer.
Herzlichen Glückwunsch, dass Sie es bis zum Ende geschafft haben! In diesem Beitrag haben Sie Folgendes erfahren:
- Das Speicherredundanzproblem im Commonplace-DDP
- Wie ZeRO Optimiererzustände, Verläufe und Parameter auf GPUs verteilt
- Die drei Phasen von ZeRO und ihre Kompromisse zwischen Gedächtnis und Kommunikation
- So verwenden Sie ZeRO-3 über PyTorchs FSDP
Im nächsten Artikel befassen wir uns mit der Tensor-Parallelität, einer Modellparallelitätstechnik, die die Berechnung einer Ebene beschleunigt, indem sie die Arbeit auf GPUs verteilt.
