Llama von Grund auf: So setzt man ein Paper um, ohne zu verzweifeln
(blog.briankitano.com)- Brian Kitano baut selbst eine verkleinerte Llama mit TinyShakespeare und fasst zusammen, dass man ein Paper sicherer umsetzt, wenn man mit einem kleinen Modell beginnt, Komponenten einzeln austauscht und nach jedem Schritt trainiert und evaluiert
- Zuerst richtet er Hilfsfunktionen zur Validierung ein, etwa für Datensplitting, Batch-Erzeugung, Loss-Auswertung und Generierungsfunktionen, prüft dann mit einem einfachen Modell, ob Kompilierung und Training funktionieren, und fügt erst danach Llama-Bausteine hinzu
- RMSNorm, RoPE und SwiGLU werden der Reihe nach ergänzt, wobei mit Tensor-Shapes, Eigenschaften der Formeln und Attention-Maps überprüft wird, ob jede Schicht wie erwartet arbeitet
- Entfernt man in der RoPE-Attention die causal mask, sinkt der Validierungs-Loss zwar bis auf 0,16, die Generierungsqualität wird jedoch schlechter; Ursache ist Information Leakage durch den Blick auf zukünftige Tokens
- Die finale verkleinerte Llama hat 4 Blöcke und rund 2,37 Mio. Parameter, senkt den Validierungs-Loss auf etwa 1,0 und zeigt, dass auch Gradient-Flow und Learning-Rate-Schedule gemeinsam geprüft werden müssen
Klein anfangen und iterativ Sicherheit gewinnen
- Beim Umsetzen eines Papers ist entscheidend, mit einem kleinen Modell zu beginnen, Komponenten einzeln auszutauschen und nach jeder Änderung Training und Evaluation zu wiederholen
- Zuerst werden Hilfsfunktionen vorbereitet, mit denen sich das Modell quantitativ prüfen lässt
- Datensplitting
- Trainingsloop
- Loss-Visualisierung
- Auswertung des Validierungs-Loss
- Statt alle Paper-Bausteine auf einmal zu übernehmen, wird auch eine qualitative Evaluationsfunktion eingerichtet, die mit einem einfachen und schnellen Modell, für das bereits Implementierungserfahrung vorliegt, die Generierungsergebnisse sichtbar macht
- Tensor-Layer werden mit
.shape,assertundplt.imshowgeprüft; statt sofort in die Optimierung von Matrixmultiplikationen einzusteigen, werden erwartete Ergebnisse zunächst von Hand gegengeprüft und anschließend mittorcheffizient umgesetzt - Man sollte mit Batch-Größe, Sequenzlänge und Embedding-Dimension variieren; Code, der nur für eine einzige Größe korrekt ist, kann beim Inferenzzeitpunkt scheitern
Datensatz und Grundeinstellungen
- Ziel der Implementierung ist eine stark verkleinerte Version von Meta AIs Llama, trainiert auf TinyShakespeare
- Llama wird ursprünglich mit 1,4T Tokens trainiert, hier kommt jedoch TinyShakespeare mit rund 1,11 Mio. Zeichen zum Einsatz
- Die originale Llama verwendet einen SentencePiece Byte-Pair-Encoding-Tokenizer, diese Implementierung nutzt jedoch einen einfachen zeichenbasierten Tokenizer
- vocabulary size ist 65
- Wegen des kleinen Datensatzes wird die Speicherhaltung nicht gesondert optimiert
- Ein
MASTER_CONFIG-Dictionary verwaltet Modelleinstellungen wievocab_size,batch_size,context_windowundd_model- Ziel ist, Konstanten und Magic Numbers zu reduzieren und den Code lesbarer zu machen
- Die Funktion
get_batchesteilt die Daten in 80 % Train, 10 % Val und 10 % Test auf und erzeugt von zufälligen Startpunkten Eingabenxsowie Labelsy, die um ein Zeichen verschoben sind
Mit einem Basismodell Kompilierung und Training prüfen
- Das erste Modell ist
SimpleBrokenModel, bestehend aus Embeddings und einem einfachen Feed-Forward-Netzwerknn.EmbeddingLinearReLULinear
- Bei der Umsetzung eines Papers bedeutet „das Modell funktioniert“, dass beide Bedingungen erfüllt sind
- Kompilierung: Tensor-Shapes passen zwischen den Layern zusammen
- Training: Der Loss sinkt tatsächlich
- Die Funktion
evaluate_losssampelt auf den Splits Train und Val jeweils 10 Batches und berechnet den mittleren Loss SimpleBrokenModelerreichte nach 1000 Epochen einen Validierungs-Loss von 3,94 und verbesserte sich gegenüber der anfänglichen Cross-Entropy von 4,17 kaum- Ursache war, dass an
F.cross_entropybereits per Softmax verarbeitete Werte übergeben wurden- PyTorchs
F.cross_entropyerwartet direkt unnormalisierte Logits SimpleModelohne Softmax senkte den Validierungs-Loss auf etwa 2,51
- PyTorchs
- Danach wurde eine
generate-Funktion ergänzt, um die vom Modell erzeugten Zeichen direkt zu prüfen; das Basismodell ist zwar unvollständig, aber der Validierungs-Loss sinkt nun
Llama-Baustein 1: RMSNorm
- Llama verwendet gegenüber dem ursprünglichen Transformer drei zentrale Architekturänderungen
- RMSNorm pre-normalization
- Rotary embeddings
- SwiGLU activation function
- Der ursprüngliche Transformer nutzt BatchNormalization, Llama dagegen RMSNorm, das Vektoren nicht zentriert, sondern über die Varianz skaliert
- Während der ursprüngliche Transformer Normalisierung als Post-Normalization auf die Ausgabe der Attention-Layer anwendet, setzt Llama auf Pre-Normalization direkt auf der Eingabe
- Das implementierte
RMSNormgeht von einer Eingabeform(batch, seq_len, d_model)aus - Das RMSNorm-Ergebnis wird über die Eigenschaft getestet, dass die Layer-Norm der Quadratwurzel aus der Anzahl der Layer-Elemente entspricht
assert- zeilenweiser Vergleich
torch.allclose
SimpleModel_RMS, das RMSNorm zum Basismodell hinzufügt, senkt den Validierungs-Loss leicht auf etwa 2,5015
Llama-Baustein 2: RoPE und causal mask
- RoPE ist ein Verfahren zur Positionskodierung für Transformer und stellt Token-Positionen als Rotationen im Embedding-Raum dar
get_rotary_matrixerzeugt für Kontextfenster und Embedding-Dimension positionsabhängige Rotationsmatrizen- Die RoPE-Implementierung wird anhand der folgenden Eigenschaft getestet
- Das Skalarprodukt zweier an den Positionen
mundnrotierter Vektoren muss einer Rotation um die relative Positionn-mentsprechen
- Das Skalarprodukt zweier an den Positionen
RoPEAttentionHeaderzeugtw_q,w_kundw_v, wendet RoPE-Rotationen auf Query und Key an und nutzt anschließendF.scaled_dot_product_attention- Auf Unterschiede der Tensor-Shapes zwischen Trainings- und Inferenzzeitpunkt muss geachtet werden
- Beim Training passen die Formen oft zu den Konfigurationswerten wie
(config['batch_size'], config['context_window'], config['d_model']) - Bei der Inferenz kann ein einzelnes Beispiel wie
(1, 1, config['d_model'])verarbeitet werden - Innerhalb von
forwardsollte beim Indexing nicht von Modellkonstanten, sondern von den Shapes der Eingabe ausgegangen werden
- Beim Training passen die Formen oft zu den Konfigurationswerten wie
- Das Modell mit ergänzter RoPE-Multi-Head-Attention, aber ohne causal mask, senkte den Validierungs-Loss stark auf 0,1623, erzeugte jedoch schlechte Ausgaben wie
OOOO...oderIIII... - Ein Blick auf die Attention-Map zeigte, dass alle Positionen auf alle Positionen zugreifen konnten; bei der Vorhersage des nächsten Tokens entstand dadurch Information Leakage durch zukünftige Tokens
- Nach dem Wechsel auf
RoPEMaskedAttentionHeadmitis_causal=TrueinF.scaled_dot_product_attentionwurde die obere Dreiecksmatrix der Attention in Richtung Zukunft nahezu 0 - Nach Anwendung der causal mask lag der Validierungs-Loss bei 2,0815 und sank bei längerem Training weiter auf 1,8985
Llama-Baustein 3: SwiGLU und das Stapeln von Blöcken
- Llama ersetzt die ReLU-Nichtlinearität durch die SwiGLU activation function
- Das implementierte
SwiGLUist eine Swish-gated linear unit und verwendet zwei lineare Transformationen sowie einen lernbaren Parameterbeta RopeModelmit SwiGLU im Feed-Forward-Teil hatte 592.706 Parameter und erreichte einen Validierungs-Loss von etwa 1,8963- Danach wird ein
LlamaBlockerstellt, der folgende Struktur in einem Block bündelt- RMSNorm pre-normalization
- masked RoPE multi-head attention
- residual connection
- RMSNorm pre-normalization
- SwiGLU feed-forward
- residual connection
- Das finale
Llama-Modell setztn_layers=4und stapelt mitnn.Sequentialauf Basis vonOrderedDictvierLlamaBlocks - Das finale Modell hat 2.370.246 Parameter; die Trainingsergebnisse lauten
- nach dem ersten Training des 4-Layer-Modells Validierungs-Loss 1,5532
- nach weiterem Training über 10.000 Epochen Validierungs-Loss 1,1479
- nach zusätzlichem Training Validierungs-Loss 0,9997
- Loss eines Batches aus dem Test-Split: 1,2358
Generierungsergebnisse und Debugging-Prüfpunkte
- Das finale Modell erzeugt Shakespeare-ähnliche Namen, Zeilenumbrüche und Wortfragmente, die tatsächliche Satzqualität bleibt jedoch begrenzt
- Der Cross-Entropy-Loss lässt sich aus Sicht der Token-Auswahl intuitiv deuten
- Der Anfangs-Loss von 4,17 entspricht bei einer vocabulary size von 65 ungefähr einer Zufallsauswahl
- Ein Loss von 1,08 lässt sich so interpretieren, als würde zufällig aus etwa 2,9 Tokens gewählt
- Der Gradient-Flow wird mit der Funktion
show_gradsüberprüft- Für jeden Parameter wird der Anteil an Gradienten mit kleinem Absolutwert berechnet
- Wenn bei den meisten Parametern die Gradienten nicht nahe 0 liegen, ist der Flow in gutem Zustand
- Die originale Llama verwendet ein Cosine Annealing learning schedule, in dieser Implementierung verschlechterten sich die Ergebnisse damit jedoch
- In den Cosine-Annealing-Experimenten erhielt der Attention-Bias selbst bei sehr niedriger Toleranz kaum Signal; die Ursache ist unklar, daher ist es in der Praxis sicherer, zunächst einfach zu beginnen
1 Kommentare
Meinungen auf Hacker News
Es scheint einen Bug in der SwiGLU-Implementierung zu geben: In der Referenzarbeit ist Beta im Feed-forward Network kein lernbarer Wert, sondern eine Konstante, und es wird als
FFnSwiGLU = Swish1...gesetztGrundlage ist Gleichung 6 in https://arxiv.org/pdf/2002.05202.pdf
Auch in der offiziellen Llama-Implementierung ist das konstante Beta entfernt: https://github.com/facebookresearch/llama/blob/main/llama/mo...
In den Blog-Logs sieht man an den Zeilen
"feedforward.1.beta', 0.0", dass Beta während des Trainings auf 0 degeneriert ist; eigentlich sollte es die Konstante 1 seinOft passt sich das Netzwerk an Änderungen an, ob beabsichtigt oder nicht, und nach dem Training verhalten sich verschiedene Architekturvarianten teils ähnlich, sodass unklar sein kann, ob sie exakt mit dem Original übereinstimmen müssen
Eine Methode, solche Fehler zu finden, ist, die Ausgaben exakt mit einer Referenzimplementierung abzugleichen. Wie bei den tiny-random-Modellen von HuggingFace müssen die Ausgaben selbst mit zufälligen Gewichten exakt gleich sein; wenn nicht, ist das ein Bug-Signal
Allerdings funktioniert diese Methode vor allem bei Bugs, die während der Inferenz auftreten; Probleme, die nur bei Datenverarbeitung, Optimizer oder Training entstehen, sind schwerer zu finden
Persönlich vermute ich, dass es an ihrer autoregressiven und ODE-artigen Eigenschaft liegt, aber sicher bin ich mir nicht
Die Arbeit ist hervorragend, aber in den frühen
SimpleBrokenModelundSimpleModelgibt es ziemlich viele verschwendete Operationen. Die Reihenfolge istembedding 65 -> 128,linear 128 -> 128,ReLU,linear 128 -> 65; zwischen den ersten beiden Schichten gibt es keine Nichtlinearität, und da beide linear sind, ist die zweite lineare Schicht im Grunde nutzlosDieses Modell entspricht letztlich einem klassischen MLP mit einer einzelnen Hidden Layer; gemessen an FLOPS werden
128*128=16kOperationen von insgesamt128*128+65*128=24kverschwendetDie Embedding-Schicht ist eine spezielle Struktur, die Token-Indizes in Embedding-Vektoren umwandelt, daher kann man sie wohl nicht entfernen
Insgesamt zeigt es die Grundprinzipien sehr gut. Besonders gefällt mir „Benutze
.shapereligiös.assertundplt.imshowsind deine Freunde“, und Vor- und Nachbedingungen von Shapes sollte man immer per assert prüfenIch frage mich auch, ob
bearodertypeguardsolche Prüfungen per Decorator unterstützenAllerdings scheint mir der Teil „Wähle ein kleines, einfaches und schnelles Modell und schreibe Helper, um es qualitativ zu evaluieren“ eher quantitative Evaluation zu meinen. Nur so bekommt man eine numerische Baseline, mit der man fortgeschrittenere Techniken vergleichen kann
Auch der Rat, die Komponenten eines Papers einzeln zu implementieren, sollte präziser sein. Papers probieren meist mehrere Änderungen auf einmal aus und zeigen dann mit Ablation Studies den Beitrag jedes Elements. Daher halte ich es für besser, mit den zentralen Architekturänderungen zu beginnen und dann in der Reihenfolge der größten Effekte aus den Ablations, unter Beachtung der Abhängigkeiten, jede atomare Änderung zu evaluieren
bearodertypeguardkann man dank https://peps.python.org/pep-0646/ manches direkt in Python-Typannotationen hineinpressenZum Beispiel kann man Shapes pro Achse in Typen wie
ndarray[float, Dim1, *Shape]ausdrücken und die Rückgabe-Shape abhängig vomaxis-Wert überladenbear/typeguardgrundlegende Runtime-Prüfungen von Matrix-ShapesTrotzdem dürfte Python kaum so gut sein wie Julia. Julias Typsystem kann viel leichter garantieren, dass Matrixgrößen zusammenpassen
Ich frage mich, nach welchem Prinzip SwiGLU statt ReLU verwendet wird. Ich weiß nicht, ob die Autoren einfach alle möglichen Nichtlinearitäten ausprobiert haben oder ob es einen tieferen Grund gibt
bearblog wird gerade DDoS angegriffen, daher hier das Repository: https://github.com/bkitano/llama-from-scratch
Aus der Perspektive von jemandem, der KI lernt, habe ich die im Artikel vorkommenden Begriffe kurz zusammengefasst. Token sind ganzzahlige Bezeichner für Textstücke; bei LLMs werden innerhalb eines begrenzten Vokabularumfangs häufig verwendete Zeichenfragmente zusammengefasst
Eine Verlustfunktion ist ein Wert, der die Differenz zwischen Vorhersage und korrekter Antwort misst; je niedriger, desto besser. PyTorch ist eine Bibliothek für den Umgang mit Tensoren und neuronalen Netzen, und ein Tensor ist ein mehrdimensionales Zahlen-Array, das Skalare, Vektoren und Matrizen umfasst
Ein neuronales Netz ist eine Verbindungsstruktur aus Neuronen mit Gewichten und Biases, und ein Linear Layer ist eine einfache Struktur, bei der alle Eingaben und Ausgaben verbunden sind. ReLU ist eine Aktivierungsfunktion wie
Math.max(0, x); wenn man nur Linear Layer stapelt, entspricht das am Ende einer einzigen linearen Funktion, daher fügt man Nichtlinearität hinzu, um die Lernfähigkeit zu erhöhenEin Gradient ist eine numerische Änderungsgröße, die während des Trainings berechnet wird, um das Modell genauer zu machen, und Batch-Normalisierung ist eine Methode, die laufenden Zahlen anzupassen, um das Training zu unterstützen. Positional Encoding teilt die relativen Positionen der Tokens als Vektoren mit
Der Operator
@in Python ist ein Alias für__matmul__und wird für Matrixmultiplikation verwendet. Eine Epoche bedeutet, den gesamten Datensatz einmal zu trainieren, und ein Batch ist die Anzahl der Datenpunkte, die vor einer Parameteraktualisierung auf einmal eingespeist werdenAttention ist der Kernmechanismus, der LLMs funktionieren lässt: Eingabetokens werden parallel verarbeitet, daraus werden Zwischentensoren erzeugt, die anschließend zum Generieren der Ausgabetokens verwendet werden
Zum Beispiel könnte
writ, das inwriting,writtenundwritergemeinsam vorkommt, ein einzelnes Token sein, undwriterkönnte inwritundertokenisiert werdenEmbedding ist der Schritt, der solche Tokens in eindeutige numerische Repräsentationen umwandelt
Wenn es eine bestehende Implementierung und Checkpoints eines Modells gibt, ist die effektivste Methode, die eigene Implementierung zu prüfen, diesen Checkpoint zu laden und die Ausgabewerte zu vergleichen
Wenn die Ausgabe nicht stimmt, ist meist ein Detail der Implementierung falsch; man kann dann systematisch jede Schicht nachverfolgen und die tatsächliche Abweichung finden. Dabei entdeckt man mitunter auch Merkwürdigkeiten in der bestehenden Implementierung
Das bezieht sich auf das Modell selbst; Training ist eine separate Achse. Wenn man die Hyperparameter aber einigermaßen ähnlich eingestellt hat, läuft es bei korrekter Modellimplementierung in der Regel ganz ordentlich
Sowohl die Hinweise zum Lesen von Papers als auch der Inhalt dieses Papers sind gut, und auch Karpathys Makemore-Serie ist empfehlenswert
Die zusammengefassten Ratschläge sind sehr gut, und der Tipp, Tensor-Shapes mit assert zu prüfen, gilt meiner Meinung nach für jede allgemeine lineare-Algebra-Bibliothek. Beim Schreiben komplexen linearen-Algebra-Codes ist es sehr wichtig, in kleinen Schritten vorzugehen und defensiv zu programmieren
Lineare Algebra in Mainstream-Sprachen zu programmieren ist schrecklich, weil es keine Compile-Time-Shape-Prüfung gibt. Die Shape eines Tensors sollte Teil des Typs sein, und wenn man versucht,
3x4und3x4ohne Transposition zu multiplizieren, sollte das schon gar nicht kompilierenNach einer langen Berechnung an einer Operation mit Dimensionskonflikt zu scheitern, ist wirklich das Schlimmste
Ich finde auch, dass bei PyTorch-Tensoren das Gerät statisch typisiert sein sollte. Derzeit bekommt man einen Laufzeitfehler, wenn man versucht, einen Tensor im CPU-Speicher mit einem Tensor im GPU-Speicher zu multiplizieren