2 Punkte von GN⁺ 2023-08-10 | 1 Kommentare | Auf WhatsApp teilen
  • Brian Kitano baut selbst eine verkleinerte Llama mit TinyShakespeare und fasst zusammen, dass man ein Paper sicherer umsetzt, wenn man mit einem kleinen Modell beginnt, Komponenten einzeln austauscht und nach jedem Schritt trainiert und evaluiert
  • Zuerst richtet er Hilfsfunktionen zur Validierung ein, etwa für Datensplitting, Batch-Erzeugung, Loss-Auswertung und Generierungsfunktionen, prüft dann mit einem einfachen Modell, ob Kompilierung und Training funktionieren, und fügt erst danach Llama-Bausteine hinzu
  • RMSNorm, RoPE und SwiGLU werden der Reihe nach ergänzt, wobei mit Tensor-Shapes, Eigenschaften der Formeln und Attention-Maps überprüft wird, ob jede Schicht wie erwartet arbeitet
  • Entfernt man in der RoPE-Attention die causal mask, sinkt der Validierungs-Loss zwar bis auf 0,16, die Generierungsqualität wird jedoch schlechter; Ursache ist Information Leakage durch den Blick auf zukünftige Tokens
  • Die finale verkleinerte Llama hat 4 Blöcke und rund 2,37 Mio. Parameter, senkt den Validierungs-Loss auf etwa 1,0 und zeigt, dass auch Gradient-Flow und Learning-Rate-Schedule gemeinsam geprüft werden müssen

Klein anfangen und iterativ Sicherheit gewinnen

  • Beim Umsetzen eines Papers ist entscheidend, mit einem kleinen Modell zu beginnen, Komponenten einzeln auszutauschen und nach jeder Änderung Training und Evaluation zu wiederholen
  • Zuerst werden Hilfsfunktionen vorbereitet, mit denen sich das Modell quantitativ prüfen lässt
    • Datensplitting
    • Trainingsloop
    • Loss-Visualisierung
    • Auswertung des Validierungs-Loss
  • Statt alle Paper-Bausteine auf einmal zu übernehmen, wird auch eine qualitative Evaluationsfunktion eingerichtet, die mit einem einfachen und schnellen Modell, für das bereits Implementierungserfahrung vorliegt, die Generierungsergebnisse sichtbar macht
  • Tensor-Layer werden mit .shape, assert und plt.imshow geprüft; statt sofort in die Optimierung von Matrixmultiplikationen einzusteigen, werden erwartete Ergebnisse zunächst von Hand gegengeprüft und anschließend mit torch effizient umgesetzt
  • Man sollte mit Batch-Größe, Sequenzlänge und Embedding-Dimension variieren; Code, der nur für eine einzige Größe korrekt ist, kann beim Inferenzzeitpunkt scheitern

Datensatz und Grundeinstellungen

  • Ziel der Implementierung ist eine stark verkleinerte Version von Meta AIs Llama, trainiert auf TinyShakespeare
  • Llama wird ursprünglich mit 1,4T Tokens trainiert, hier kommt jedoch TinyShakespeare mit rund 1,11 Mio. Zeichen zum Einsatz
  • Die originale Llama verwendet einen SentencePiece Byte-Pair-Encoding-Tokenizer, diese Implementierung nutzt jedoch einen einfachen zeichenbasierten Tokenizer
    • vocabulary size ist 65
    • Wegen des kleinen Datensatzes wird die Speicherhaltung nicht gesondert optimiert
  • Ein MASTER_CONFIG-Dictionary verwaltet Modelleinstellungen wie vocab_size, batch_size, context_window und d_model
    • Ziel ist, Konstanten und Magic Numbers zu reduzieren und den Code lesbarer zu machen
  • Die Funktion get_batches teilt die Daten in 80 % Train, 10 % Val und 10 % Test auf und erzeugt von zufälligen Startpunkten Eingaben x sowie Labels y, die um ein Zeichen verschoben sind

Mit einem Basismodell Kompilierung und Training prüfen

  • Das erste Modell ist SimpleBrokenModel, bestehend aus Embeddings und einem einfachen Feed-Forward-Netzwerk
    • nn.Embedding
    • Linear
    • ReLU
    • Linear
  • Bei der Umsetzung eines Papers bedeutet „das Modell funktioniert“, dass beide Bedingungen erfüllt sind
    • Kompilierung: Tensor-Shapes passen zwischen den Layern zusammen
    • Training: Der Loss sinkt tatsächlich
  • Die Funktion evaluate_loss sampelt auf den Splits Train und Val jeweils 10 Batches und berechnet den mittleren Loss
  • SimpleBrokenModel erreichte nach 1000 Epochen einen Validierungs-Loss von 3,94 und verbesserte sich gegenüber der anfänglichen Cross-Entropy von 4,17 kaum
  • Ursache war, dass an F.cross_entropy bereits per Softmax verarbeitete Werte übergeben wurden
    • PyTorchs F.cross_entropy erwartet direkt unnormalisierte Logits
    • SimpleModel ohne Softmax senkte den Validierungs-Loss auf etwa 2,51
  • Danach wurde eine generate-Funktion ergänzt, um die vom Modell erzeugten Zeichen direkt zu prüfen; das Basismodell ist zwar unvollständig, aber der Validierungs-Loss sinkt nun

Llama-Baustein 1: RMSNorm

  • Llama verwendet gegenüber dem ursprünglichen Transformer drei zentrale Architekturänderungen
    • RMSNorm pre-normalization
    • Rotary embeddings
    • SwiGLU activation function
  • Der ursprüngliche Transformer nutzt BatchNormalization, Llama dagegen RMSNorm, das Vektoren nicht zentriert, sondern über die Varianz skaliert
  • Während der ursprüngliche Transformer Normalisierung als Post-Normalization auf die Ausgabe der Attention-Layer anwendet, setzt Llama auf Pre-Normalization direkt auf der Eingabe
  • Das implementierte RMSNorm geht von einer Eingabeform (batch, seq_len, d_model) aus
  • Das RMSNorm-Ergebnis wird über die Eigenschaft getestet, dass die Layer-Norm der Quadratwurzel aus der Anzahl der Layer-Elemente entspricht
    • assert
    • zeilenweiser Vergleich
    • torch.allclose
  • SimpleModel_RMS, das RMSNorm zum Basismodell hinzufügt, senkt den Validierungs-Loss leicht auf etwa 2,5015

Llama-Baustein 2: RoPE und causal mask

  • RoPE ist ein Verfahren zur Positionskodierung für Transformer und stellt Token-Positionen als Rotationen im Embedding-Raum dar
  • get_rotary_matrix erzeugt für Kontextfenster und Embedding-Dimension positionsabhängige Rotationsmatrizen
  • Die RoPE-Implementierung wird anhand der folgenden Eigenschaft getestet
    • Das Skalarprodukt zweier an den Positionen m und n rotierter Vektoren muss einer Rotation um die relative Position n-m entsprechen
  • RoPEAttentionHead erzeugt w_q, w_k und w_v, wendet RoPE-Rotationen auf Query und Key an und nutzt anschließend F.scaled_dot_product_attention
  • Auf Unterschiede der Tensor-Shapes zwischen Trainings- und Inferenzzeitpunkt muss geachtet werden
    • Beim Training passen die Formen oft zu den Konfigurationswerten wie (config['batch_size'], config['context_window'], config['d_model'])
    • Bei der Inferenz kann ein einzelnes Beispiel wie (1, 1, config['d_model']) verarbeitet werden
    • Innerhalb von forward sollte beim Indexing nicht von Modellkonstanten, sondern von den Shapes der Eingabe ausgegangen werden
  • Das Modell mit ergänzter RoPE-Multi-Head-Attention, aber ohne causal mask, senkte den Validierungs-Loss stark auf 0,1623, erzeugte jedoch schlechte Ausgaben wie OOOO... oder IIII...
  • Ein Blick auf die Attention-Map zeigte, dass alle Positionen auf alle Positionen zugreifen konnten; bei der Vorhersage des nächsten Tokens entstand dadurch Information Leakage durch zukünftige Tokens
  • Nach dem Wechsel auf RoPEMaskedAttentionHead mit is_causal=True in F.scaled_dot_product_attention wurde die obere Dreiecksmatrix der Attention in Richtung Zukunft nahezu 0
  • Nach Anwendung der causal mask lag der Validierungs-Loss bei 2,0815 und sank bei längerem Training weiter auf 1,8985

Llama-Baustein 3: SwiGLU und das Stapeln von Blöcken

  • Llama ersetzt die ReLU-Nichtlinearität durch die SwiGLU activation function
  • Das implementierte SwiGLU ist eine Swish-gated linear unit und verwendet zwei lineare Transformationen sowie einen lernbaren Parameter beta
  • RopeModel mit SwiGLU im Feed-Forward-Teil hatte 592.706 Parameter und erreichte einen Validierungs-Loss von etwa 1,8963
  • Danach wird ein LlamaBlock erstellt, der folgende Struktur in einem Block bündelt
    • RMSNorm pre-normalization
    • masked RoPE multi-head attention
    • residual connection
    • RMSNorm pre-normalization
    • SwiGLU feed-forward
    • residual connection
  • Das finale Llama-Modell setzt n_layers=4 und stapelt mit nn.Sequential auf Basis von OrderedDict vier LlamaBlocks
  • Das finale Modell hat 2.370.246 Parameter; die Trainingsergebnisse lauten
    • nach dem ersten Training des 4-Layer-Modells Validierungs-Loss 1,5532
    • nach weiterem Training über 10.000 Epochen Validierungs-Loss 1,1479
    • nach zusätzlichem Training Validierungs-Loss 0,9997
    • Loss eines Batches aus dem Test-Split: 1,2358

Generierungsergebnisse und Debugging-Prüfpunkte

  • Das finale Modell erzeugt Shakespeare-ähnliche Namen, Zeilenumbrüche und Wortfragmente, die tatsächliche Satzqualität bleibt jedoch begrenzt
  • Der Cross-Entropy-Loss lässt sich aus Sicht der Token-Auswahl intuitiv deuten
    • Der Anfangs-Loss von 4,17 entspricht bei einer vocabulary size von 65 ungefähr einer Zufallsauswahl
    • Ein Loss von 1,08 lässt sich so interpretieren, als würde zufällig aus etwa 2,9 Tokens gewählt
  • Der Gradient-Flow wird mit der Funktion show_grads überprüft
    • Für jeden Parameter wird der Anteil an Gradienten mit kleinem Absolutwert berechnet
    • Wenn bei den meisten Parametern die Gradienten nicht nahe 0 liegen, ist der Flow in gutem Zustand
  • Die originale Llama verwendet ein Cosine Annealing learning schedule, in dieser Implementierung verschlechterten sich die Ergebnisse damit jedoch
  • In den Cosine-Annealing-Experimenten erhielt der Attention-Bias selbst bei sehr niedriger Toleranz kaum Signal; die Ursache ist unklar, daher ist es in der Praxis sicherer, zunächst einfach zu beginnen

1 Kommentare

 
GN⁺ 2023-08-10
Meinungen auf Hacker News
  • Es scheint einen Bug in der SwiGLU-Implementierung zu geben: In der Referenzarbeit ist Beta im Feed-forward Network kein lernbarer Wert, sondern eine Konstante, und es wird als FFnSwiGLU = Swish1... gesetzt
    Grundlage ist Gleichung 6 in https://arxiv.org/pdf/2002.05202.pdf
    Auch in der offiziellen Llama-Implementierung ist das konstante Beta entfernt: https://github.com/facebookresearch/llama/blob/main/llama/mo...
    In den Blog-Logs sieht man an den Zeilen "feedforward.1.beta', 0.0", dass Beta während des Trainings auf 0 degeneriert ist; eigentlich sollte es die Konstante 1 sein

    • Das zeigt, wie schwierig es ist, ein Transformer-Neuronales Netz korrekt zu implementieren. Man kann an vielen Stellen Fehler machen, und meist zeigt sich das nur als „etwas schlechtere Performance als ursprünglich“, sodass es schwer ist, es sicher zu erkennen
      Oft passt sich das Netzwerk an Änderungen an, ob beabsichtigt oder nicht, und nach dem Training verhalten sich verschiedene Architekturvarianten teils ähnlich, sodass unklar sein kann, ob sie exakt mit dem Original übereinstimmen müssen
      Eine Methode, solche Fehler zu finden, ist, die Ausgaben exakt mit einer Referenzimplementierung abzugleichen. Wie bei den tiny-random-Modellen von HuggingFace müssen die Ausgaben selbst mit zufälligen Gewichten exakt gleich sein; wenn nicht, ist das ein Bug-Signal
      Allerdings funktioniert diese Methode vor allem bei Bugs, die während der Inferenz auftreten; Probleme, die nur bei Datenverarbeitung, Optimizer oder Training entstehen, sind schwerer zu finden
    • Ich denke, dass Bias-Werte in Transformern im Allgemeinen eher schlecht passen
      Persönlich vermute ich, dass es an ihrer autoregressiven und ODE-artigen Eigenschaft liegt, aber sicher bin ich mir nicht
  • Die Arbeit ist hervorragend, aber in den frühen SimpleBrokenModel und SimpleModel gibt es ziemlich viele verschwendete Operationen. Die Reihenfolge ist embedding 65 -> 128, linear 128 -> 128, ReLU, linear 128 -> 65; zwischen den ersten beiden Schichten gibt es keine Nichtlinearität, und da beide linear sind, ist die zweite lineare Schicht im Grunde nutzlos
    Dieses Modell entspricht letztlich einem klassischen MLP mit einer einzelnen Hidden Layer; gemessen an FLOPS werden 128*128=16k Operationen von insgesamt 128*128+65*128=24k verschwendet

    • Offenbar bin ich nicht die einzige Person, die Nichtlinearitäten noch lernt. Ich frage mich, ob die beste Korrektur hier darin besteht, zwischen Embedding und erster linearer Schicht ReLU oder SwiGLU einzufügen, oder ob man die lineare Schicht einfach löschen sollte
      Die Embedding-Schicht ist eine spezielle Struktur, die Token-Indizes in Embedding-Vektoren umwandelt, daher kann man sie wohl nicht entfernen
  • Insgesamt zeigt es die Grundprinzipien sehr gut. Besonders gefällt mir „Benutze .shape religiös. assert und plt.imshow sind deine Freunde“, und Vor- und Nachbedingungen von Shapes sollte man immer per assert prüfen
    Ich frage mich auch, ob bear oder typeguard solche Prüfungen per Decorator unterstützen
    Allerdings scheint mir der Teil „Wähle ein kleines, einfaches und schnelles Modell und schreibe Helper, um es qualitativ zu evaluieren“ eher quantitative Evaluation zu meinen. Nur so bekommt man eine numerische Baseline, mit der man fortgeschrittenere Techniken vergleichen kann
    Auch der Rat, die Komponenten eines Papers einzeln zu implementieren, sollte präziser sein. Papers probieren meist mehrere Änderungen auf einmal aus und zeigen dann mit Ablation Studies den Beitrag jedes Elements. Daher halte ich es für besser, mit den zentralen Architekturänderungen zu beginnen und dann in der Reihenfolge der größten Effekte aus den Ablations, unter Beachtung der Abhängigkeiten, jede atomare Änderung zu evaluieren

    • Statt bear oder typeguard kann man dank https://peps.python.org/pep-0646/ manches direkt in Python-Typannotationen hineinpressen
      Zum Beispiel kann man Shapes pro Achse in Typen wie ndarray[float, Dim1, *Shape] ausdrücken und die Rückgabe-Shape abhängig vom axis-Wert überladen
    • PyTorch kenne ich nicht gut, aber als ich zuletzt nachgesehen habe, war es dort nicht so; Jax unterstützt über bear / typeguard grundlegende Runtime-Prüfungen von Matrix-Shapes
      Trotzdem dürfte Python kaum so gut sein wie Julia. Julias Typsystem kann viel leichter garantieren, dass Matrixgrößen zusammenpassen
  • Ich frage mich, nach welchem Prinzip SwiGLU statt ReLU verwendet wird. Ich weiß nicht, ob die Autoren einfach alle möglichen Nichtlinearitäten ausprobiert haben oder ob es einen tieferen Grund gibt

    • Wie bei viel Forschung gilt: Wenn es keine klare, durch saubere Forschung gestützte Erklärung gibt, ist es wahrscheinlich, dass sie zufällig per Hill-Climbing einzeilige Änderungen ausprobiert haben, die cool aussahen, und aufgehört haben, als es Zeit wurde, das Paper zu schreiben und Ablation Studies zu machen
  • bearblog wird gerade DDoS angegriffen, daher hier das Repository: https://github.com/bkitano/llama-from-scratch

  • Aus der Perspektive von jemandem, der KI lernt, habe ich die im Artikel vorkommenden Begriffe kurz zusammengefasst. Token sind ganzzahlige Bezeichner für Textstücke; bei LLMs werden innerhalb eines begrenzten Vokabularumfangs häufig verwendete Zeichenfragmente zusammengefasst
    Eine Verlustfunktion ist ein Wert, der die Differenz zwischen Vorhersage und korrekter Antwort misst; je niedriger, desto besser. PyTorch ist eine Bibliothek für den Umgang mit Tensoren und neuronalen Netzen, und ein Tensor ist ein mehrdimensionales Zahlen-Array, das Skalare, Vektoren und Matrizen umfasst
    Ein neuronales Netz ist eine Verbindungsstruktur aus Neuronen mit Gewichten und Biases, und ein Linear Layer ist eine einfache Struktur, bei der alle Eingaben und Ausgaben verbunden sind. ReLU ist eine Aktivierungsfunktion wie Math.max(0, x); wenn man nur Linear Layer stapelt, entspricht das am Ende einer einzigen linearen Funktion, daher fügt man Nichtlinearität hinzu, um die Lernfähigkeit zu erhöhen
    Ein Gradient ist eine numerische Änderungsgröße, die während des Trainings berechnet wird, um das Modell genauer zu machen, und Batch-Normalisierung ist eine Methode, die laufenden Zahlen anzupassen, um das Training zu unterstützen. Positional Encoding teilt die relativen Positionen der Tokens als Vektoren mit
    Der Operator @ in Python ist ein Alias für __matmul__ und wird für Matrixmultiplikation verwendet. Eine Epoche bedeutet, den gesamten Datensatz einmal zu trainieren, und ein Batch ist die Anzahl der Datenpunkte, die vor einer Parameteraktualisierung auf einmal eingespeist werden
    Attention ist der Kernmechanismus, der LLMs funktionieren lässt: Eingabetokens werden parallel verarbeitet, daraus werden Zwischentensoren erzeugt, die anschließend zum Generieren der Ausgabetokens verwendet werden

    • Außerhalb des Fachgebiets wissen manche vielleicht nicht, was „Karpathy“ bedeutet. Wenn man Andrej Karpathy mit Kontext vorstellt, etwa als „Wissenschaftskommunikator und Forscher“, wird klarer, dass man auf seine Texte oder Videos verweist
    • Ein Token ist für Anfänger genauer betrachtet weniger einfach ein ganzzahliger Bezeichner für ein Textstück, sondern eher ein Wortfragment, das häufig genug ist, um für sich genommen nützlich zu sein
      Zum Beispiel könnte writ, das in writing, written und writer gemeinsam vorkommt, ein einzelnes Token sein, und writer könnte in writ und er tokenisiert werden
      Embedding ist der Schritt, der solche Tokens in eindeutige numerische Repräsentationen umwandelt
    • Wenn man lineare Funktionen zusammensetzt, erhält man wieder eine lineare Funktion. Wenn also alles linear ist, sind bei mehreren gestapelten Schichten alle bis auf eine verschwendet; um das zu vermeiden, braucht man Nichtlinearität
    • Neben Karpathys Videoserie und dem accompanying repo würde mich interessieren, ob es auf dem Lernweg noch weitere besonders hilfreiche Materialien oder Bücher gab
    • Mich interessiert, was Batch-Normalisierung genau tut und wie sie hilft
  • Wenn es eine bestehende Implementierung und Checkpoints eines Modells gibt, ist die effektivste Methode, die eigene Implementierung zu prüfen, diesen Checkpoint zu laden und die Ausgabewerte zu vergleichen
    Wenn die Ausgabe nicht stimmt, ist meist ein Detail der Implementierung falsch; man kann dann systematisch jede Schicht nachverfolgen und die tatsächliche Abweichung finden. Dabei entdeckt man mitunter auch Merkwürdigkeiten in der bestehenden Implementierung
    Das bezieht sich auf das Modell selbst; Training ist eine separate Achse. Wenn man die Hyperparameter aber einigermaßen ähnlich eingestellt hat, läuft es bei korrekter Modellimplementierung in der Regel ganz ordentlich

  • Sowohl die Hinweise zum Lesen von Papers als auch der Inhalt dieses Papers sind gut, und auch Karpathys Makemore-Serie ist empfehlenswert

  • Die zusammengefassten Ratschläge sind sehr gut, und der Tipp, Tensor-Shapes mit assert zu prüfen, gilt meiner Meinung nach für jede allgemeine lineare-Algebra-Bibliothek. Beim Schreiben komplexen linearen-Algebra-Codes ist es sehr wichtig, in kleinen Schritten vorzugehen und defensiv zu programmieren
    Lineare Algebra in Mainstream-Sprachen zu programmieren ist schrecklich, weil es keine Compile-Time-Shape-Prüfung gibt. Die Shape eines Tensors sollte Teil des Typs sein, und wenn man versucht, 3x4 und 3x4 ohne Transposition zu multiplizieren, sollte das schon gar nicht kompilieren
    Nach einer langen Berechnung an einer Operation mit Dimensionskonflikt zu scheitern, ist wirklich das Schlimmste
    Ich finde auch, dass bei PyTorch-Tensoren das Gerät statisch typisiert sein sollte. Derzeit bekommt man einen Laufzeitfehler, wenn man versucht, einen Tensor im CPU-Speicher mit einem Tensor im GPU-Speicher zu multiplizieren