einer Serie über verteilte KI über mehrere GPUs:

Einführung

Im vorherigen Beitrag haben wir gesehen, wie Distributed Knowledge Parallelism (DDP) das Coaching beschleunigt, indem es Stapel auf GPUs aufteilt. DDP löst das Durchsatzproblem, bringt jedoch eine neue Herausforderung mit sich: Speicherredundanz.

In Vanilla DDP enthält jede GPU eine vollständige Kopie der Modellparameter, Verläufe und Optimiererzustände. Bei großen Modellen wie GPT-3 (175B-Parameter) führt diese Redundanz zu einer großen Verschwendung wertvollen VRAMs.

Bild des Autors: Modell, Farbverläufe und Optimierer sind im regulären DDP auf allen GPUs redundant

ZeRO (Zero Redundancy Optimizer) löst dieses Drawback. Es gibt drei Ebenen:

  • ZeRO-1 Partitionen nur Optimiererzustände
  • ZeRO-2 Partitioniert Optimiererzustände + Farbverläufe
  • ZeRO-3 Partitioniert Optimiererzustände + Verläufe + Modellparameter

ZeRO ist keine Parallelitätstechnik, da alle GPUs immer noch die gleichen Vorwärts- und Rückwärtsdurchläufe ausführen. Es ist ein Speicheroptimierung Strategie, die Redundanz zwischen GPUs eliminiert und es Ihnen ermöglicht, größere Modelle auf derselben {Hardware} zu trainieren.

Das Speicherproblem in DDP

Lassen Sie uns aufschlüsseln, was während des Trainings tatsächlich Speicher verbraucht. Für ein Modell mit Parametern:

  • Modellparameter: Werte (die Gewichte Ihres neuronalen Netzwerks)
  • Farbverläufe: Werte (ein Gradient professional Parameter)
  • Optimiererzustände (Adam): Werte (erstes Second und zweites Second für jeden Parameter)
  • Aktivierungen: Zwischenausgaben, die während des Vorwärtsdurchlaufs zur Verwendung im Rückwärtsdurchlauf gespeichert werden

Die ersten drei skalieren mit Modellgröße und sind GPU-übergreifend redundant im DDP. Die Aktivierungen skalieren mit der Batch-Größe, der Sequenzlänge und der Anzahl der Neuronen Einzigartig professional GPU da jede GPU unterschiedliche Daten verarbeitet. ZeRO berührt den Aktivierungsspeicher nicht.

Berechnen wir die Speichernutzung für ein 7B-Parameter-Modell mit Adam und FP32:

  • Parameter: 7 Milliarden * 4 Bytes = 28 GB
  • Farbverläufe: 7 Milliarden * 4 Bytes = 28 GB
  • Der Optimierer gibt an: 7 Milliarden * 2 * 4 Bytes = 56 GB
  • Speicher professional GPU in DDP: 112 GB

Aktivierungen fügen darüber hinaus erheblichen Speicher hinzu, aber da sie professional GPU einzigartig sind, kann ZeRO sie nicht partitionieren. Techniken wie Aktivierungs-Checkpointing kann helfen, es verwirft einige Aktivierungen und berechnet sie dann nach Bedarf während des Rückwärtsdurchlaufs neu. Aber das würde den Rahmen dieses Artikels sprengen.

Lassen Sie uns verstehen, wie ZeRO funktioniert, indem wir es von Grund auf implementieren, beginnend mit ZeRO-1 und uns bis hin zu ZeRO-3 vorarbeiten.

ZeRO-1: Optimierer-Zustandspartitionierung

In ZeRO-1 nur die Optimiererzustände sind partitioniert. Jede GPU:

  • Hält immer noch vollständige Modellparameter und -verläufe
  • Nur Geschäfte 1/N der Optimiererzustände (N = Anzahl der GPUs)
  • Aktualisiert nur das entsprechende 1/N der Parameter

Dies ist die Reihenfolge der während des Trainings durchgeführten Aktionen:

  1. Vorwärtspass: Jede GPU verarbeitet ihren eigenen Mikrobatch
  2. Rückwärtspass: Berechnen Sie Farbverläufe
  3. all-reduce Steigungen: Jede GPU erhält alle Farbverläufe
  4. Optimierungsschritt: Jede GPU aktualisiert ihre Parameterpartition
  5. all-gather Parameter: Synchronisieren Sie das aktualisierte Modell über GPUs hinweg
Bild vom Autor: Zero 1-Animation

Hier ist eine vereinfachte Implementierung:

import torch
import torch.distributed as dist


class ZeRO_1:
    def __init__(self, mannequin, optimizer_cls):
        self.mannequin = mannequin
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.param_shards = listing()  # every rank holds solely its shard of the optimizer states
        self.param_metadata = listing()  # metadata to reconstruct shards

        for param in self.mannequin.parameters():
            original_shape = param.information.form
            flat = param.information.view(-1)
            numel = flat.numel()

            the rest = numel % self.world_size
            pad_size = (self.world_size - the rest) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size

            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat((flat, flat.new_zeros(pad_size)))
            else:
                flat_padded = flat

            shard = flat_padded(shard_start:shard_end).clone()
            shard.requires_grad_(True)
            self.param_shards.append(shard)

        self.optimizer = optimizer_cls(self.param_shards)

    def training_step(self, inputs, targets, loss_fn):
        output = self.mannequin(inputs) # ahead
        loss = loss_fn(output, targets) # compute loss
        loss.backward() # backward

        self._sync_gradients()  # all-reduce gradients throughout GPUs
        self.optimizer.step() # replace native shard of parameters
        self._sync_params() # all collect mannequin params

        # clear gradients for the subsequent step
        for param in self.mannequin.parameters():
            param.grad = None

    def _sync_gradients(self):
        for idx, param in enumerate(self.mannequin.parameters()):
            meta = self.param_metadata(idx)

            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size

            self.param_shards(idx).grad = param.grad.view(-1)(meta("shard_start"):meta("shard_end"))

    def _sync_params(self):
        for idx, param in enumerate(self.mannequin.parameters()):
            meta = self.param_metadata(idx)

            full_flat = torch.empty(meta("padded_numel"), gadget=param.gadget, dtype=param.dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards(idx).information,
            )
            
            reconstructed = full_flat(:meta("numel")).view(meta("original_shape"))
            param.information.copy_(reconstructed)

Beachten Sie, dass die all-reduce synchronisiert alle Farbverläufe, aber jede GPU verwendet die Farbverläufe nur für ihre eigene Parameterpartition, es kommt zu einer Überkommunikation. ZeRO-2 behebt dieses Drawback, indem es auch die Farbverläufe aufteilt.

In der Praxis würden Sie ZeRO-1 niemals verwenden, da Sie mit ZeRO-2 bei im Wesentlichen gleichen Kosten bessere Speichereinsparungen erzielen. Aber es lohnt sich trotzdem, es zu Lernzwecken noch einmal durchzugehen.

Speicher mit ZeRO-1, 7B-Modell, 8 GPUs:

  • Parameter: 28 GB (vollständig repliziert)
  • Farbverläufe: 28 GB (vollständig repliziert)
  • Optimierer gibt an: 56 GB / 8 = 7 GB
  • Insgesamt professional GPU: 63 GB (von GB)

ZeRO-2: Gradientenpartitionierung

ZeRO-2 partitioniert beide Optimiererzustände und Farbverläufe. Da jede GPU nur eine Partition von Parametern aktualisiert, benötigt sie nur die entsprechenden Farbverläufe.

ZeRO-1 verwendet all-reducewas jeder GPU alle Farbverläufe verleiht. ZeRO-2 ersetzt dies durch reduce-scattererhält jede GPU nur die Farbverläufe, die sie tatsächlich benötigt. Dies spart sowohl Speicher als auch Kommunikationsbandbreite.

Trainingsschritte:

  1. Vorwärtspass: Jede GPU verarbeitet ihren eigenen Mikrobatch
  2. Rückwärtspass: Berechnen Sie Farbverläufe
  3. reduce-scatter Steigungen: Jede GPU erhält nur ihre Partition
  4. Optimierungsschritt: Jede GPU aktualisiert ihre Parameterpartition
  5. all-gather Parameter: Synchronisieren Sie das aktualisierte Modell über GPUs hinweg
Bild vom Autor: Zero 2-Animation

Die Implementierung ist ZeRO-1 sehr ähnlich, verwendet jedoch den Gradientensynchronisationsschritt reduce-scatter anstatt all-reduce:
Aber Second, wenn jede GPU während des Backprops alle Farbverläufe berechnet, wie spart das dann tatsächlich VRAM? So geht’s:

  • Da die Parameterverläufe Schicht für Schicht berechnet werden, sind sie sofort sichtbar reduce-scattered und die lokale Kopie wird freigegeben (unsere vereinfachte Implementierung führt dies nicht durch).
  • Beim Backprop benötigen Sie nur den Gradienten der nächsten Neuronenaktivierung, um den Gradienten des aktuellen Parameters zu berechnen, dh Sie benötigen nicht das gesamte Gradientendiagramm.
  • Auf diese Weise können Sie den Speicher für Farbverläufe freigeben, während Sie sich rückwärts bewegen, wobei nur die zugewiesene Partition für jede GPU erhalten bleibt.

Speicher mit ZeRO-2, 7B-Modell, 8 GPUs:

  • Parameter: 28 GB (vollständig repliziert)
  • Farbverläufe: 28 GB / 8 = 3,5 GB
  • Optimierer gibt an: 56 GB / 8 = 7 GB
  • Gesamt professional GPU: 38,5 GB (vorher 112 GB)

ZeRO-3: Parameterpartitionierung

ZeRO-3 partitioniert Optimiererzustände, -verläufe usw Parameter. Jede GPU speichert nur 1/N des gesamten Modellstatus.

Bei Vorwärts- und Rückwärtsdurchläufen benötigt jede Schicht ihre vollständigen Parameter, aber jede GPU speichert nur einen Bruchteil. Additionally wir Alle Parameter Simply-in-Time erfassenverwenden Sie sie und entsorgen Sie sie sofort danach.

Trainingsschritte:

  • Vorwärtspass:
    • Sammeln Sie alle Parameter der Ebene von allen GPUs
    • Führen Sie den Vorwärtsdurchlauf der Ebene aus und verwenden Sie dabei die Aktivierungen der vorherigen Ebene als Eingabe
    • Verwerfen Sie die gesammelten Parameter (behalten Sie nur die lokale Partition bei)
    • Wiederholen Sie diese Schritte, bis alle Schichten fertig sind
  • Rückwärtsdurchlauf (professional Schicht, rückwärts):
    • Erfassen Sie erneut alle Parameter der Ebene
    • Berechnen Sie die Farbverläufe für die aktuelle Ebene mithilfe der Aktivierungsgradienten der nächsten Ebene
    • Reduzieren Sie die Streuung der Farbverläufe (jede GPU behält ihren Shard)
    • Verwerfen Sie die gesammelten Parameter (behalten Sie nur die lokale Partition bei)
    • Wiederholen Sie diese Schritte, bis alle Schichten fertig sind
  • Jede GPU führt einen Optimierungsschritt auf ihrer Partition aus
  • Es ist keine abschließende Gesamterfassung erforderlich, da die Parameter während des Vorwärtsdurchlaufs Schicht für Schicht erfasst werden
Bild vom Autor: Zero 3-Animation

Hier ist eine vereinfachte Implementierung:

class ZeRO_3(ZeRO_2):
    """
    ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + mannequin parameters (stage 3).

    At relaxation, every rank holds solely param_shards(idx) — a 1/world_size slice
    of every parameter. Full parameters are materialised briefly throughout
    the ahead and backward passes by way of all_gather, then instantly freed.
    """

    def __init__(self, mannequin, optimizer_cls):
        self.mannequin = mannequin
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.param_metadata = ()
        shard_list = ()

        self._param_to_idx = {}

        for idx, param in enumerate(self.mannequin.parameters()):
            original_shape = param.information.form
            flat = param.information.view(-1)
            numel = flat.numel()

            the rest = numel % self.world_size
            pad_size = (self.world_size - the rest) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size

            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat((flat, flat.new_zeros(pad_size)))
            else:
                flat_padded = flat

            shard = flat_padded(shard_start:shard_end).clone()
            shard_list.append(shard)

            # Substitute the total tensor with solely this rank's shard.
            # The mannequin's param.information now factors to a tiny slice; the total
            # weight will likely be reconstructed on demand throughout ahead/backward.
            param.information = shard.detach()
            self._param_to_idx(param) = idx

        self.param_shards = (s.requires_grad_(True) for s in shard_list)
        self.optimizer = optimizer_cls(self.param_shards)

        self._register_hooks()

    def _gather_param(self, idx, gadget, dtype):
        """All-gather the total parameter tensor for parameter `idx`."""
        meta = self.param_metadata(idx)
        full_flat = torch.empty(meta("padded_numel"), gadget=gadget, dtype=dtype)
        dist.all_gather_into_tensor(
            output_tensor=full_flat,
            input_tensor=self.param_shards(idx).information,
        )
        return full_flat(: meta("numel")).view(meta("original_shape"))

    def _gather_module_params(self, module):
        """Collect full params for each parameter that belongs to this module solely (not kids)."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx(param)
            param.information = self._gather_param(idx, param.gadget, param.dtype)

    def _reshard_module_params(self, module):
        """Reshard params again to native shard for each direct param of this module."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx(param)
            param.information = self.param_shards(idx).information

    def _register_hooks(self):
        self._hooks = ()
        for module in self.mannequin.modules():
            # Skip container modules that haven't any direct parameters
            if not listing(module.parameters(recurse=False)):
                proceed

            # Ahead: collect -> run -> reshard
            h1 = module.register_forward_pre_hook(
                lambda mod, _inputs: self._gather_module_params(mod)
            )
            h2 = module.register_forward_hook(
                lambda mod, _inputs, _output: self._reshard_module_params(mod)
            )

            # Backward: collect earlier than grad computation → reshard after
            h3 = module.register_full_backward_pre_hook(
                lambda mod, _grad_output: self._gather_module_params(mod)
            )
            h4 = module.register_full_backward_hook(
                lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
            )

            self._hooks.lengthen((h1, h2, h3, h4))

    def training_step(self, inputs, targets, loss_fn):
        # Hooks deal with all collect/reshard round every module robotically
        output = self.mannequin(inputs)
        loss = loss_fn(output, targets)
        loss.backward()

        self._sync_gradients()

        # Every rank updates solely its native shard
        self.optimizer.step()

        for param in self.mannequin.parameters():
            param.grad = None

Die Parameter jeder Ebene werden unmittelbar vor ihrem Bedarf erfasst und unmittelbar danach freigegeben. Dadurch wird der Spitzenspeicher auf ein Minimal reduziert, allerdings auf Kosten einer höheren Kommunikation. In der Praxis überlappen Implementierungen das All-Collect für Layer N+1 mit dem Ahead von Layer N, um die Latenz zu verbergen.

Speicher mit ZeRO-3, 7B-Modell, 8 GPUs:

  • Parameter: 28 GB / 8 = 3,5 GB
  • Farbverläufe: 28 GB / 8 = 3,5 GB
  • Optimierer gibt an: 56 GB / 8 = 7 GB
  • Insgesamt professional GPU: 14 GB (vorher 112 GB)

Das ist ein 8-fache Reduzierung bei der Speichernutzung, was genau das ist, was wir von einer Partitionierung auf 8 GPUs erwarten würden.

Verwendung von ZeRO in PyTorch

PyTorch wird mit zwei Implementierungen von ZeRO-3 ausgeliefert: FSDP1 (älter, weniger optimiert) und FSDP2 (neuer, empfohlen). Verwenden Sie immer FSDP2.

FSDP (Absolutely Sharded Knowledge Parallel) übernimmt automatisch die Parametererfassung, Gradientenstreuung, Kommunikationsüberlappung und Speicherverwaltung:

from torch.distributed.fsdp import fully_shard

mannequin = Transformer()
for layer in mannequin.layers:
    fully_shard(layer)
fully_shard(mannequin)

Sie müssen sich bewerben fully_shard Schicht für Schicht und wickeln Sie dann das gesamte Modell ein.

Abschluss

ZeRO tauscht Erinnerung gegen Kommunikation aus, es ist additionally kein kostenloses Mittagessen. Im Allgemeinen lohnt es sich für kleinere Modelle (z. B. BERT) nicht, aber für größere Modelle ist es ein Recreation Changer.

Herzlichen Glückwunsch, dass Sie es bis zum Ende geschafft haben! In diesem Beitrag haben Sie Folgendes erfahren:

  • Das Speicherredundanzproblem im Commonplace-DDP
  • Wie ZeRO Optimiererzustände, Verläufe und Parameter auf GPUs verteilt
  • Die drei Phasen von ZeRO und ihre Kompromisse zwischen Gedächtnis und Kommunikation
  • So verwenden Sie ZeRO-3 über PyTorchs FSDP

Im nächsten Artikel befassen wir uns mit der Tensor-Parallelität, einer Modellparallelitätstechnik, die die Berechnung einer Ebene beschleunigt, indem sie die Arbeit auf GPUs verteilt.

Referenzen

  1. ZeRO: Speicheroptimierungen für das Coaching von Billionen-Parametermodellen (Originalpapier)
  2. PyTorch FSDP-Tutorial
  3. FSDP-API-Referenz
  4. Das Extremely-Scale-Playbook von Huggging Face

Von admin

Schreibe einen Kommentar

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert