2 Punkte von GN⁺ 2024-09-24 | 1 Kommentare | Auf WhatsApp teilen

Felafax BlogTune Llama3 405B auf AMD MI300x (unsere Reise)

Einführung

  • Mit dem Wachstum von Open-Source-Modellen steigt der Bedarf an leistungsfähiger Infrastruktur für das Training großer KI-Modelle
  • Felafax hat das Modell LLaMA 3.1 405B auf AMD-GPUs feinabgestimmt und damit die Effizienz von AMD-Hardware demonstriert
  • Die gesamte Arbeit wurde als Open Source auf GitHub veröffentlicht
  • AMD-MI300X-GPUs bieten im Vergleich zu NVIDIA-AI-Hardware eine hohe Leistung
  • Das Projekt wurde durch die Unterstützung von TensorWave ermöglicht

Was ist JAX und warum wurde es gewählt?

  • JAX ist eine leistungsstarke Machine-Learning-Bibliothek, die eine NumPy-ähnliche API, automatische Differenzierung und Googles XLA-Compiler kombiniert
  • Es bietet hervorragende APIs für Modellparallelisierung und ist daher ideal für das Training großer Modelle

Vorteile von JAX

  • Reine Funktionen: JAX fördert das Schreiben reiner Funktionen, wodurch sich Code leichter strukturieren, debuggen und lesen lässt
  • Fortgeschrittene Parallelisierung: Die flexible JIT-API von JAX unterstützt fortgeschrittene Daten- und Modellparallelisierung, die für groß angelegtes Training essenziell ist
  • Saubere Codebasis: Die Designphilosophie von JAX fördert das Schreiben von Code, der zwischen Hardwareplattformen portierbar ist

Warum JAX auf nicht von NVIDIA stammender Hardware herausragt

  • Hardwareunabhängiger Ansatz: JAX nutzt den XLA-Compiler, um Berechnungen in eine hardwareunabhängige Zwischenrepräsentation zu kompilieren
  • Plattformunabhängige Optimierung: Der XLA-Compiler führt Optimierungen unabhängig von der Hardware durch
  • Einfache Portierbarkeit: Mit JAX sind beim Wechsel von NVIDIA zu AMD nur minimale Codeänderungen nötig

JAX auf AMD-GPUs einrichten

  • Ein Docker-Image wurde geladen, der Container gestartet und anschließend die Installation überprüft
  • Das Modell LLaMA 405B wurde mit 8 AMD-MI300X-GPUs trainiert

Training von LLaMA 405B: Leistung und Skalierbarkeit

  • Das Modell LLaMA 405B wurde mit JAX auf AMD-GPUs trainiert
  • Beim LoRA-Feintuning wurden Modellgewichte und LoRA-Parameter mit bfloat16-Präzision angepasst
  • Modellgröße: belegt etwa 800 GB VRAM
  • LoRA-Gewichte und Optimizer-Zustände: belegen etwa 400 GB VRAM
  • Gesamte VRAM-Nutzung: etwa 1200 GB
  • Trainingsgeschwindigkeit: etwa 35 Token pro Sekunde
  • Speichereffizienz: hält etwa 70 %
  • Skalierbarkeit: mit JAX nahezu linear über 8 GPUs skalierbar

Unser Trainings-Setup

  • LLaMA 3.1 wurde von PyTorch nach JAX konvertiert
  • Das Modell wurde durch Modellladen und Parameter-Sharding effizient verteilt

Parameter-Sharding in JAX

  • Mithilfe der Device-Mesh-Funktion von JAX wurde das Modell effizient auf 8 AMD-GPUs verteilt
  • Durch die Definition von Parameter-Sharding-Regeln wurden die Dimensionen jedes Tensors entlang der Mesh-Achsen geshardet

Implementierung des LoRA-Trainings

  • LoRA reduziert die Zahl trainierbarer Parameter, indem Gewichtsaktualisierungen in niedrig-rangige Matrizen zerlegt werden
  • Es wurde eine LoRADense-Schicht implementiert, die LoRA-Parameter enthält
  • LoRA-Parameter wurden effizient verteilt, um Speichernutzung und Recheneffizienz zu optimieren

Fazit

  • Die Erfahrung mit der Feinabstimmung des Modells LLaMA 3.1 405B auf AMD-GPUs mit JAX war sehr positiv
  • Mithilfe der leistungsstarken Parallelisierungsfunktionen von JAX und seines hardwareunabhängigen Ansatzes wurde das Modell effizient verteilt
  • Es wurde gezeigt, dass AMD-GPUs eine starke Alternative für das Training großer KI-Modelle sind
  • Der vollständige Code kann im GitHub-Repository eingesehen und direkt ausgeführt werden

Zusammenfassung von GN⁺

  • Dieser Artikel erklärt, wie sich große KI-Modelle mit AMD-GPUs und JAX effizient trainieren lassen
  • Es wird hervorgehoben, dass AMD-Hardware im Vergleich zu NVIDIA eine kosteneffiziente Alternative ist
  • Der hardwareunabhängige Ansatz von JAX erhöht die Portierbarkeit des Codes und erleichtert die Wartung
  • Der Artikel bietet nützliche Informationen und Praxiscode für alle, die sich für das Training großer Modelle interessieren
  • Projekte mit ähnlichen Funktionen sind unter anderem NVIDIA CUDA und PyTorch

1 Kommentare

 
GN⁺ 2024-09-24
Hacker-News-Kommentare
  • Ergebnisse beim Fine-Tuning des Modells Llama3.1 405B mit JAX auf 8xAMD MI300x GPUs geteilt

    • Dank der fortschrittlichen Sharding-API von JAX wurde eine hervorragende Performance erreicht
    • Links zum Blogpost und zum Open-Source-Code bereitgestellt: GitHub-Link
    • Es handelt sich um ein Startup, das eine KI-Infrastruktur zum Fine-Tuning und Bereitstellen von LLMs auf TPU, AMD und Trainium statt auf NVIDIA-Hardware aufbaut
    • Viele Unternehmen versuchen, PyTorch auf AMD GPUs zum Laufen zu bringen, aber das wird als schwieriger Weg eingeschätzt
    • PyTorch ist eng mit dem NVIDIA-Ökosystem verbunden, sodass viele Anpassungen nötig sind, um es auf Nicht-NVIDIA-Hardware zu betreiben
    • Es wird angenommen, dass JAX besser für Nicht-NVIDIA-Hardware geeignet ist
    • In JAX wird ML-Modellcode in hardwareunabhängige HLO-Graphen kompiliert, und der XLA-Compiler führt hardwarespezifische Optimierungen aus
    • Derselbe JAX-Code kann ohne Änderungen sowohl auf Google TPU als auch auf AMD GPUs laufen
    • Die Unternehmensstrategie besteht darin, Modelle nach JAX zu portieren und mit XLA-Kernels maximale Performance aus Nicht-NVIDIA-Backends herauszuholen
    • Llama 3.1 wurde zunächst von PyTorch nach JAX portiert, und nun funktioniert dasselbe JAX-Modell sowohl auf TPUs als auch auf AMD GPUs gut
    • Es wird um Feedback zur Vision und zum Repository gebeten
  • Vorschlag, Wege zu erkunden, um Speicherbeschränkungen zu überwinden und eine JIT-kompilierte Version auszuführen

    • Das könnte zusätzliche Performance-Verbesserungen bringen
  • Erfahrungen mit AMD GPUs und ROCm-Support geteilt

    • Vor einem Jahr wurde versucht, AMD GPUs und ROCm-Support zu nutzen, aber es entstand der Eindruck, dass AMD noch weit davon entfernt ist, NVIDIA einzuholen
    • Die Wahl von JAX ist ein interessanter Ansatz, aber es wird gefragt, welche Schwierigkeiten es beim Verlassen von PyTorch gab
  • Erfahrungen mit Experimenten zur Inferenz des 405B-Modells geteilt

    • torch.cuda sei ihrer Meinung nach gar nicht so schlecht
    • Da die AMD-Version von PyTorch dies übersetzt, wird es als bloßes Namensproblem betrachtet
    • Die Verwendung des rocm:pytorch-Containers sei genauso einfach wie die des rocm:jax-Containers
    • Es wird darauf hingewiesen, dass kaum Performance-Daten veröffentlicht wurden
    • Es wird nach den MFU-Werten (Model Flops Utilization) gefragt
  • Frage nach dem Fehlen von Performance-Daten

    • Es wird infrage gestellt, ob sich durch Großbestellungen von AMD GPUs überhaupt ein Mehrwert herausholen lässt
    • Der Eindruck lautet: „Nein“
  • Frage, warum Obsidian (Notiz-App) das tut

    • Zunächst wurde angenommen, es handle sich um einen Beitrag von Obsidian
    • Es wird gefragt, warum GitHub.com und GitHub.io noch immer nicht unterschieden werden
  • Bitte an @dang, den Benutzernamen in die URL aufzunehmen

    • Dieser Beitrag betrifft keinen offiziellen Beitrag von Obsidian selbst, sondern einen nutzergenerierten Blog