Inspiriert von Andrej Kapathys jüngstem Youtube-Video über Lassen Sie uns GPT-2 (124 M) reproduzieren.ich möchte es mit den meisten Trainingsoptimierungen in Jax neu aufbauen. Jax ist für hocheffiziente Rechengeschwindigkeiten ausgelegt, und es ist ziemlich interessant, Pytorch mit seiner jüngsten Trainingsoptimierung und Jax mit seinen zugehörigen Bibliotheken wie Flax (Layers API für neuronales Netzwerktraining für Jax) und Optax (eine Gradientenverarbeitungs- und Optimierungsbibliothek für JAX) zu vergleichen. Wir werden schnell lernen, was Jax ist, und das GPT mit Jax neu aufbauen. Am Ende werden wir die Token/Sekunde mit MultiGPU-Coaching zwischen Pytorch und Jax vergleichen!
Was ist Jax?
Basierend auf seiner lesenthedocJAX ist eine Python-Bibliothek für beschleunigerorientierte Array-Berechnung und Programmtransformation, die für hochleistungsfähige numerische Berechnungen und maschinelles Lernen im großen Maßstab entwickelt wurde. Ich möchte JAX mit seinem Namen vorstellen. Während manche es Simply One other nennen XLA (Accelerated Linear Algibra), ich nenne es lieber J(it) A(utograd) X(LA), um seine Fähigkeit zur hohen Effizienz zu demonstrieren.
J – Simply-in-time (JIT)-Kompilierung. Wenn Sie Ihre Python-Funktion ausführen, konvertiert Jax sie in einen primitiven Operationssatz namens Jaxpr. Anschließend wird der Jaxpr-Ausdruck in eine Eingabe für XLA umgewandelt, das die Skripte auf niedrigerer Ebene kompiliert, um ein optimiertes ausführbares Programm für das Zielgerät (CPU, GPU oder TPU) zu erstellen.
A — Autograd. Die Berechnung von Gradienten ist ein wichtiger Teil moderner Methoden des maschinellen Lernens. Sie können einfach jax.grad()
um Gradienten zu erhalten, mit denen Sie die Modelle optimieren können.
X — XLA. Dies ist ein Open-Supply-Compiler für maschinelles Lernen für CPU-, GPU- und ML-Beschleuniger. Im Allgemeinen führt XLA mehrere integrierte Optimierungs- und Analysedurchläufe für die StabilHLO Graph, sendet dann die HLO-Berechnung an ein Backend für weitere Optimierungen auf HLO-Ebene. Das Backend führt dann eine zielspezifische Codegenerierung durch.
Dies sind nur einige der Hauptfunktionen von JAX, aber es gibt auch viele benutzerfreundliche numpy-ähnliche APIs in jax.numpy
und automatische Vektorisierung mit jax.vmap
und parallelisieren Sie Ihre Codes auf mehreren Geräten über jax.pmap
. Wir werden in weiteren Blogs weitere Jax-Konzepte und -Anwendungen behandeln, aber jetzt wollen wir den NanoGPT mit Jax reproduzieren!
Von der Aufmerksamkeit zum Transformator
GPT ist ein Transformer-Modell nur für Decoder, und der wichtigste Baustein ist das Consideration-Modul. Wir können zunächst eine Modellkonfigurationsdatenklasse definieren, um die Modellhyperparameter des Modells zu speichern, damit das Modellmodul sie effizient nutzen kann, um die Modellarchitektur zu initialisieren. Ähnlich wie beim 124M-GPT-Modell initialisieren wir hier einen 12-Schicht-Transformer-Decoder mit 12 Köpfen und einer Vokabelgröße von 50257 Token, von denen jeder 768 Einbettungsdimensionen hat. Die Blockgröße für die Consideration-Berechnung beträgt 1024.
from dataclasses import dataclass@dataclass
class ModelConfig:
vocab_size: int = 50257
n_head: int = 12
n_embd: int = 768
block_size: int = 1024
n_layer: int = 12
dropout_rate: float = 0.1
Als nächstes kommen wir zum wichtigsten Baustein des Transformer-Modells – Aufmerksamkeit. Die Idee besteht darin, die Eingaben in drei Gewichtungsmatrizen zu verarbeiten: Schlüssel, Abfrage und Wert. Hier verlassen wir uns auf die flax
eine Jax Layer- und Trainings-API-Bibliothek zum Initialisieren der 3-Gewichtsmatrix, indem Sie einfach die flax.linen.Dense
. Wie erwähnt hat Jax viele numpy-ähnliche APIs, daher formen wir die Ausgaben nach der Gewichtsmatrix um mit jax.numpy.reshape
von (batch_size, sequence_length, embedding_dim) bis (batch_size, sequence_length, num_head, embedding_dim / num_head). Da wir eine Matrixmultiplikation auf den Schlüssel- und Wertmatrizen durchführen müssen, hat jax auch jax.numpy.matmul
API und jax.numpy.transpose
(Transponieren Sie die Schlüsselmatrix für die Multiplikation).
Beachten Sie, dass wir eine Maske auf die Aufmerksamkeitsmatrix setzen müssen, um Informationslecks zu vermeiden (verhindern, dass die vorherigen Token Zugriff auf die späteren Token haben). jax.numpy.tril
hilft beim Aufbau eines unteren Dreiecksarrays und jax.numpy.the place
kann die unendliche Zahl für uns füllen, um 0 nach Softmax zu erhalten jax.nn.softmax
. Die vollständigen Codes der Multihead-Aufmerksamkeit finden Sie unten.
from flax import linen as nn
import jax.numpy as jnpclass CausalSelfAttention(nn.Module):
config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=True):
assert len(x.form) == 3
b, l, d = x.form
q = nn.Dense(self.config.n_embd)(x)
ok = nn.Dense(self.config.n_embd)(x)
v = nn.Dense(self.config.n_embd)(x)
# q*ok / sqrt(dim) -> softmax -> @v
q = jnp.reshape(q, (b, l, d//self.config.n_head , self.config.n_head))
ok = jnp.reshape(ok, (b, l, d//self.config.n_head , self.config.n_head))
v = jnp.reshape(v, (b, l, d//self.config.n_head , self.config.n_head))
norm = jnp.sqrt(checklist(jnp.form(ok))(-1))
attn = jnp.matmul(q,jnp.transpose(ok, (0,1,3,2))) / norm
masks = jnp.tril(attn)
attn = jnp.the place(masks(:,:,:l,:l), attn, float("-inf"))
probs = jax.nn.softmax(attn, axis=-1)
y = jnp.matmul(probs, v)
y = jnp.reshape(y, (b,l,d))
y = nn.Dense(self.config.n_embd)(y)
return y
Möglicherweise stellen Sie fest, dass es keine __init__
oder ahead
Methoden, wie Sie in pytorch sehen können. Dies ist das Besondere an jax, wo Sie die Ebenen explizit definieren können mit setup
Methoden oder definieren Sie sie implizit im Vorwärtsdurchlauf durch Hinzufügen nn.compact
auf __call__
Methode. (Referenz)
Als nächstes erstellen wir die MLP- und Block-Ebene, die die Dense-Ebene, die Gelu-Aktivierungsfunktion, LayerNorm und Dropout umfasst. Auch hier verfügt flax.linen über die Layer-APIs, die uns beim Erstellen des Moduls helfen. Beachten Sie, dass wir ein deterministic
Boolesche Variable zur Steuerung unterschiedlicher Verhaltensweisen während des Trainings oder der Auswertung für einige Ebenen wie Dropout.
class MLP(nn.Module):config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=True):
x = nn.Dense(self.config.n_embd*4)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dropout(price=self.config.dropout_rate)(x, deterministic=deterministic)
x = nn.Dense(self.config.n_embd)(x)
x = nn.Dropout(price=self.config.dropout_rate)(x, deterministic=deterministic)
return x
class Block(nn.Module):
config: ModelConfig
@nn.compact
def __call__(self, x):
x = nn.LayerNorm()(x)
x = x + CausalSelfAttention(self.config)(x)
x = nn.LayerNorm()(x)
x = x + MLP(self.config)(x)
return x
Verwenden wir nun die obigen Blöcke, um den NanoGPT zu erstellen:
Angesichts der Eingaben einer Sequenz von Token-IDs verwenden wir die flax.linen.Embed
Schicht, um Positionseinbettungen und Tokeneinbettungen zu erhalten. Dann übergeben wir sie N-mal an das Blockmodul, wobei N die Anzahl der in der Modellkonfiguration definierten Schichten ist. Am Ende ordnen wir die Ausgaben des letzten Blocks den Wahrscheinlichkeiten für jedes Token im Vokabular zu, um das nächste Token vorherzusagen. Neben der Vorwärts- __call__
Methode erstellen wir auch eine init
Methoden zum Abrufen der Dummy-Eingaben, um die Parameter des Modells zu erhalten.
class GPT(nn.Module):config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=False):
B, T = x.form
assert T <= self.config.block_size
pos = jnp.arange(0, T)(None)
pos_emb = nn.Embed(self.config.block_size, self.config.n_embd)(pos)
wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
tok_emb = wte(x)
x = tok_emb + pos_emb
for _ in vary(self.config.n_layer):
x = Block(self.config)(x)
x = nn.LayerNorm()(x)
logits = nn.Dense(config.n_embd, config.vocab_size)
# logits = wte.attend(x) # parameter sharing
return logits
def init(self, rng):
tokens = jnp.zeros((1, self.config.block_size), dtype=jnp.uint16)
params = jax.jit(tremendous().init, static_argnums=(2,))(rng, tokens, True)
return params
Nun überprüfen wir die Anzahl der Parameter: Wir initialisieren zuerst die Modellkonfigurationsdatenklasse und den Zufallsschlüssel, erstellen dann Dummy-Eingaben und geben diese in das GPT-Modell ein. Dann verwenden wir die jax.util.treemap
API zum Erstellen einer Zählparameterfunktion. Wir haben 124439808 (124 M) Parameter, dieselbe Menge wie Huggingfaces GPT2, BOOM!
DataLoader und Trainingsschleife
Lassen Sie uns nun einen kleinen Datensatz überanpassen. Um ihn in Andrejs Video über Pytorch NanoGPT vergleichbar zu machen, verwenden wir das Spielzeug Datensatz die er in seinem Video geteilt hat. Wir verwenden den Tokenizer von GPT2 von tiktoken
Bibliothek, um alle Texte aus der Eingabedatei zu tokenisieren und die Tokens in jax.numpy.array
für Jax‘ Modeltraining.
class DataLoader:
def __init__(self, B, T):
self.current_position = 0
self.B = B
self.T = Twith open("enter.txt","r") as f:
textual content = f.learn()
enc = tiktoken.get_encoding("gpt2")
self.tokens = jnp.array(enc.encode(textual content))
print(f"loaded {len(self.tokens)} tokens within the datasets" )
print(f" 1 epoch = {len(self.tokens)//(B*T)} batches")
def next_batch(self):
B,T = self.B, self.T
buf = self.tokens(self.current_position:self.current_position+B*T+1)
x,y = jnp.reshape(buf(:-1),(B,T)), jnp.reshape(buf(1:),(B,T))
self.current_position += B*T
if self.current_position + B*T+1 > len(self.tokens):
self.current_position = 0
return x,y
Als nächstes vergessen wir zunächst das verteilte Coaching und die Optimierung und erstellen einfach eine naive Trainingsschleife für eine Plausibilitätsprüfung. Das erste, was wir nach der Initialisierung des Modells tun müssen, ist die Erstellung eines ZugStaatein Modellzustand, in dem wir die Parameter und Gradienten aktualisieren können. Der TrainState nimmt drei wichtige Eingaben entgegen: apply_fn (Modell-Vorwärtsfunktion), params (Modellparameter aus der Init-Methode) und tx (eine Optax-Gradiententransformation).
Dann verwenden wir die Funktion train_step, um den Modellstatus (Gradienten und Parameter) zu aktualisieren und mit dem Modelltraining fortzufahren. Optax
Bereitstellung der Softmax-Kreuzentropie als Verlustfunktion für die nächste Token-Vorhersageaufgabe und jax.value_and_grad
berechnet die Gradienten und den Verlustwert für die Verlustfunktion. Schließlich aktualisieren wir den Zustand des Modells mit den neuen Parametern mithilfe der apply_gradients
API. (Referenz) Vergessen Sie nicht, die Funktion train_step zu deaktivieren, um den Rechenaufwand zu reduzieren!
def init_train_state(key, config) -> TrainState:
mannequin = GPT(config)
params = mannequin.init(key)
optimizer = optax.adamw(3e-4, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1)
train_state = TrainState.create(
apply_fn=mannequin.apply,
params=params,
tx=optimizer)
return train_state@jax.jit
def train_step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple(jnp.ndarray, TrainState):
def loss_fn(params: FrozenDict) -> jnp.ndarray:
logits = state.apply_fn(params, x, False)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).imply()
return loss
loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state.params)
new_state = state.apply_gradients(grads=grads)
return loss, new_state
Jetzt ist alles bereit für die Poorman-Trainingsschleife. Lassen Sie uns den Verlustwert überprüfen. Die Vorhersage des Modells sollte besser sein als die zufällige Schätzung, additionally sollte der Verlust niedriger sein als -ln(1/50257)≈10,825. Was wir von der Überanpassung eines einzelnen Batches erwarten, ist: Am Anfang liegt der Verlust bei etwa 10,825, dann sinkt er auf quick 0. Nehmen wir einen Batch von (x, y) und führen die Trainingsschleife 50 Mal aus. Ich füge auch ein ähnliches Protokoll hinzu, um die Trainingsgeschwindigkeit zu berechnen.
Wie wir sehen, entspricht der Verlustwert genau unseren Erwartungen und der Trainingsdurchsatz liegt bei etwa 400–500 ok Token/Sek. Das ist bereits 40x schneller als die ursprüngliche Model von Pytorch ohne jegliche Optimierung in Andrejs Video. Beachten Sie, dass wir die Jax-Skripte auf 1 A100 GPU ausführen, was den Hardwareunterschied für den Geschwindigkeitsvergleich beseitigen sollte. Es gibt keine .to(system)
Sachen, um Ihr Modell oder Ihre Daten von der Host-CPU auf die Geräte-GPU zu verschieben, was einer der Vorteile von Jax ist!
Das conflict’s additionally, wir haben es geschafft. In Teil 2 werden wir das Coaching mit weiteren Optimierungen 10x schneller machen …
Teil 2: Die Reise der Trainingsoptimierung zu 1350.000 Tokens/Sek. in einer einzigen GPU!
„Sofern nicht anders angegeben, stammen alle Bilder vom Autor“