So skalierst du dein Modell: Eine systemische Perspektive auf LLMs auf TPUs
(jax-ml.github.io)- Die Optimierung der Deep-Learning-Performance im großen Maßstab wirkt oft wie „Alchemie“, doch tatsächlich lässt sich die Modelleffizienz mit verständlichen, einfachen Prinzipien steigern
- Von einem einzelnen Beschleuniger bis zu Zehntausenden von Beschleunigern gelten relativ einfache Prinzipien überall. Wer sie versteht, kann damit unter anderem Folgendes tun:
- grob einschätzen, wie nah einzelne Teile eines Modells am theoretischen Optimum liegen
- eine Grundlage schaffen, um bei unterschiedlichen Größenordnungen geeignete Parallelisierungstechniken auszuwählen
- Kosten und Zeitaufwand für Training und Ausführung großer Transformer-Modelle abschätzen
- Algorithmen entwerfen, die die Eigenschaften bestimmter Hardware ausnutzen
- Hardware entwerfen, indem die Grenzen der aktuellen Algorithmusleistung klar verstanden werden
- Erforderliches Vorwissen
- Grundverständnis von LLMs und der Transformer-Architektur erforderlich
- Verständnis großskaliger Betriebsweisen ist nicht zwingend nötig
- Grundwissen zum LLM-Training und Erfahrung mit JAX sind von Vorteil
- Empfohlen werden ein Blogpost zur Transformer-Architektur und Folien zur LLM-Skalierung in JAX
- Ziele
- die Fähigkeit entwickeln, abzuschätzen, wie ein Modell auf gegebener Hardware sinnvoll parallelisiert werden kann
- die Fähigkeit entwickeln, Zeit- und Kostenaufwand für Training und Inferenz grob zu berechnen
Warum man sich dafür interessieren sollte
- Noch vor 3–4 Jahren mussten die meisten ML-Forschenden über solche großskaligen Optimierungen nicht viel wissen
- Heute arbeiten selbst „kleine“ Modelle nahe an den Hardwaregrenzen, weshalb ein Verständnis effizienter Verfahren im großen Maßstab unverzichtbar geworden ist
- Die Geschichte des ML lässt sich als Wechselspiel zwischen Systeminnovationen und Softwareverbesserungen verstehen
- Da aktuelle Transformer-Modelle die Hardwaregrenzen ausreizen, scheitern neue Architekturen oder Forschungsansätze in der Praxis leicht, wenn die Modelleffizienz nicht verstanden wird
- Selbst wenn ein Benchmark 20 % Leistungsgewinn zeigt, ist der praktische Nutzen gering, wenn die Hardwareeffizienz zugleich um 20 % sinkt
- Das zentrale Ziel der Modellskalierung ist es, den Durchsatz beim Hinzufügen weiterer Chips (Beschleuniger) linear zu steigern
- Das wird als „Strong Scaling“ bezeichnet
- Zusätzliche Chips verkürzen die Rechenzeit, verursachen aber Kommunikationskosten zwischen den Chips
- Dauert die Kommunikation länger als die Berechnung, gerät man in einen „Communication-Bound“-Zustand, in dem Strong Scaling nicht mehr möglich ist
- Wer die Hardware gut genug versteht, um vorherzusagen, wo diese Engpässe auftreten, kann Modelle so entwerfen oder umstrukturieren, dass sie vermieden werden
- Ziel dieses Buchs ist es, zu erklären, wie TPU-(und GPU-)Hardware funktioniert und wie sich die Transformer-Architektur so entwickelt hat, dass sie auf heutiger Hardware gut läuft
- Es soll sowohl Forschenden helfen, die neue Architekturen entwerfen, als auch Ingenieurinnen und Ingenieuren, die LLMs der aktuellen Generation möglichst schnell ausführen wollen
Gesamtüberblick
- Dieser Text ist wie folgt aufgebaut
- In Abschnitt 1 wird mit der Roofline-Analyse erklärt, welche Faktoren die Leistungsgrenzen eines Modells bestimmen (Kommunikation, Rechenleistung, Speicher)
- In Abschnitt 2 und Abschnitt 3 geht es um den inneren Aufbau von TPUs und GPUs sowie um die Verbindung zwischen Chips
- Damit werden unter anderem folgende Fragen beantwortet
- Wie schnell kann eine Matrixmultiplikation einer bestimmten Größe theoretisch ausgeführt werden
- Ab welchem Punkt wird die Berechnung durch Speicherbandbreite oder Kommunikationsbandbreite begrenzt
- Wie ist ein TPU-Cluster aufgebaut und wie lange dauert es ungefähr, Daten von einem Chip zu einem anderen zu verschieben
- Wie lassen sich verteilte Matrizen effizient multiplizieren
- Damit werden unter anderem folgende Fragen beantwortet
- In Abschnitt 4 werden die Formeln der Transformer-Architektur (Matrixgrößen, Parameterzahl, FLOPs) im Detail behandelt
- Abschnitt 5 und Abschnitt 7 bilden den Kern und stellen verschiedene Methoden vor, um Modelle über mehrere Chips zu parallelisieren
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Außerdem werden Speicherspartechniken wie ZeRO, Rematerialisation, Host offload und Gradient accumulation behandelt
- Abschnitt 6 und Abschnitt 8 zeigen am Beispiel des LLaMA-3-Modells auf TPUs, wie Training und Inferenz ablaufen und welche realen Kosten, Zeiten und Konfigurationen dabei anfallen
- Abschließend behandeln Abschnitt 9 und Abschnitt 10 praktische Methoden zum Profiling, Debugging und zur Anwendung von Parallelverarbeitung in JAX
Im Detail: Zusammenfassung der wichtigsten Abschnitte des Buchs
-
Teil 1: Preliminaries
-
Abschnitt 1: Einführung in die einfache Roofline-Analyse
- Drei Faktoren begrenzen einen Algorithmus: Rechenleistung, Kommunikation und Speicher
- Daraus lernt man, wie sich die obere Grenze der Rechengeschwindigkeit abschätzen lässt
-
Abschnitt 2: Eine Sicht auf TPUs
- Wie TPUs rechnen
- Was eine Systolic-Array-Struktur ist
- Ein grundlegendes Verständnis dafür, wie TPUs Speicher- und Kommunikationsbandbreite bereitstellen
-
Abschnitt 3: Verteilte Matrizen und verteilte Multiplikation
- Techniken zum Speichern von Modellparametern über mehrere Chips verteilt (Sharding)
- Wie Kommunikation und Engpässe bei verteilten Matrixoperationen behandelt werden
-
-
Teil 2: Transformers
-
Abschnitt 4: Überblick über die benötigten Transformer-Formeln
- Wie Matrixmultiplikationen im Transformer konkret aussehen
- Wie Parameterzahl, FLOPs, Größe des KV-Cache usw. berechnet werden
- Wie viel Rechenaufwand Attention im Vergleich zu Feed-Forward-Blöcken erfordert
-
Abschnitt 5: Parallelisierungsstrategien für das Transformer-Training
- Einführung in Data parallel, Tensor parallel, Pipeline parallel und Expert parallel
- Speichersparmaßnahmen wie ZeRO (FSDP), Rematerialisation, Gradient accumulation und Host offload
- Grundlagen, um die Parallelisierung passend zu Modellgröße und Chip-Anzahl zu konfigurieren
-
Abschnitt 6: Anwendung auf das Training von LLaMA 3 auf TPUs
- Abschätzung von Zeit- und Kostenaufwand unter der Annahme, dass ein LLaMA-3-Modell in einer realen TPU-Umgebung trainiert wird
- Konkrete Beispiele zu Batch-Größe, Parallelisierungsart und Speichernutzung
-
Abschnitt 7: Alles zur Transformer-Inferenz
- Bei der Inferenz tritt Latenz als wichtiger neuer Faktor hinzu
- Speicherverbrauch und Kommunikationsprobleme durch den KV-Cache usw.
- Diskussion darüber, wie mehrere Chips für das Modell-Serving aufgeteilt und verbunden werden sollten
-
Abschnitt 8: Anwendung auf das Serving von LLaMA 3 auf TPUs
- Analyse grober Kosten sowie der Trade-offs zwischen Latenz und Durchsatz unter der Annahme, dass LLaMA 3 auf TPU v5e betrieben wird
-
-
Teil 3: Praktische Tutorials
-
Abschnitt 9: So profiliert man TPU-Code
- Verständnis des JAX+XLA-Stacks
- Erkennen realer Performance-Probleme und möglicher Lösungen
- Einsatz des JAX-/TensorBoard-Profilers
-
Abschnitt 10: TPU-Programmierung mit JAX
- Nutzung der Parallelisierungs-APIs (Primitives) von JAX
- Parallelrechenkonzepte anhand von Beispielen und Aufgaben lernen
-
Abschnitt 11: Fazit und weitere Materialien
- Weiterführende Lektüre zu TPUs und LLMs
- Kurzer Abschluss des Gesamtinhalts mit einem Ausblick auf die Zukunft
-
1 Kommentare
Hacker-News-Kommentare