Felafax BlogTune Llama3 405B auf AMD MI300x (unsere Reise)
Einführung
- Mit dem Wachstum von Open-Source-Modellen steigt der Bedarf an leistungsfähiger Infrastruktur für das Training großer KI-Modelle
- Felafax hat das Modell LLaMA 3.1 405B auf AMD-GPUs feinabgestimmt und damit die Effizienz von AMD-Hardware demonstriert
- Die gesamte Arbeit wurde als Open Source auf GitHub veröffentlicht
- AMD-MI300X-GPUs bieten im Vergleich zu NVIDIA-AI-Hardware eine hohe Leistung
- Das Projekt wurde durch die Unterstützung von TensorWave ermöglicht
Was ist JAX und warum wurde es gewählt?
- JAX ist eine leistungsstarke Machine-Learning-Bibliothek, die eine NumPy-ähnliche API, automatische Differenzierung und Googles XLA-Compiler kombiniert
- Es bietet hervorragende APIs für Modellparallelisierung und ist daher ideal für das Training großer Modelle
Vorteile von JAX
- Reine Funktionen: JAX fördert das Schreiben reiner Funktionen, wodurch sich Code leichter strukturieren, debuggen und lesen lässt
- Fortgeschrittene Parallelisierung: Die flexible JIT-API von JAX unterstützt fortgeschrittene Daten- und Modellparallelisierung, die für groß angelegtes Training essenziell ist
- Saubere Codebasis: Die Designphilosophie von JAX fördert das Schreiben von Code, der zwischen Hardwareplattformen portierbar ist
Warum JAX auf nicht von NVIDIA stammender Hardware herausragt
- Hardwareunabhängiger Ansatz: JAX nutzt den XLA-Compiler, um Berechnungen in eine hardwareunabhängige Zwischenrepräsentation zu kompilieren
- Plattformunabhängige Optimierung: Der XLA-Compiler führt Optimierungen unabhängig von der Hardware durch
- Einfache Portierbarkeit: Mit JAX sind beim Wechsel von NVIDIA zu AMD nur minimale Codeänderungen nötig
JAX auf AMD-GPUs einrichten
- Ein Docker-Image wurde geladen, der Container gestartet und anschließend die Installation überprüft
- Das Modell LLaMA 405B wurde mit 8 AMD-MI300X-GPUs trainiert
Training von LLaMA 405B: Leistung und Skalierbarkeit
- Das Modell LLaMA 405B wurde mit JAX auf AMD-GPUs trainiert
- Beim LoRA-Feintuning wurden Modellgewichte und LoRA-Parameter mit bfloat16-Präzision angepasst
- Modellgröße: belegt etwa 800 GB VRAM
- LoRA-Gewichte und Optimizer-Zustände: belegen etwa 400 GB VRAM
- Gesamte VRAM-Nutzung: etwa 1200 GB
- Trainingsgeschwindigkeit: etwa 35 Token pro Sekunde
- Speichereffizienz: hält etwa 70 %
- Skalierbarkeit: mit JAX nahezu linear über 8 GPUs skalierbar
Unser Trainings-Setup
- LLaMA 3.1 wurde von PyTorch nach JAX konvertiert
- Das Modell wurde durch Modellladen und Parameter-Sharding effizient verteilt
Parameter-Sharding in JAX
- Mithilfe der Device-Mesh-Funktion von JAX wurde das Modell effizient auf 8 AMD-GPUs verteilt
- Durch die Definition von Parameter-Sharding-Regeln wurden die Dimensionen jedes Tensors entlang der Mesh-Achsen geshardet
Implementierung des LoRA-Trainings
- LoRA reduziert die Zahl trainierbarer Parameter, indem Gewichtsaktualisierungen in niedrig-rangige Matrizen zerlegt werden
- Es wurde eine
LoRADense-Schicht implementiert, die LoRA-Parameter enthält
- LoRA-Parameter wurden effizient verteilt, um Speichernutzung und Recheneffizienz zu optimieren
Fazit
- Die Erfahrung mit der Feinabstimmung des Modells LLaMA 3.1 405B auf AMD-GPUs mit JAX war sehr positiv
- Mithilfe der leistungsstarken Parallelisierungsfunktionen von JAX und seines hardwareunabhängigen Ansatzes wurde das Modell effizient verteilt
- Es wurde gezeigt, dass AMD-GPUs eine starke Alternative für das Training großer KI-Modelle sind
- Der vollständige Code kann im GitHub-Repository eingesehen und direkt ausgeführt werden
Zusammenfassung von GN⁺
- Dieser Artikel erklärt, wie sich große KI-Modelle mit AMD-GPUs und JAX effizient trainieren lassen
- Es wird hervorgehoben, dass AMD-Hardware im Vergleich zu NVIDIA eine kosteneffiziente Alternative ist
- Der hardwareunabhängige Ansatz von JAX erhöht die Portierbarkeit des Codes und erleichtert die Wartung
- Der Artikel bietet nützliche Informationen und Praxiscode für alle, die sich für das Training großer Modelle interessieren
- Projekte mit ähnlichen Funktionen sind unter anderem NVIDIA CUDA und PyTorch
1 Kommentare
Hacker-News-Kommentare
Ergebnisse beim Fine-Tuning des Modells Llama3.1 405B mit JAX auf 8xAMD MI300x GPUs geteilt
Vorschlag, Wege zu erkunden, um Speicherbeschränkungen zu überwinden und eine JIT-kompilierte Version auszuführen
Erfahrungen mit AMD GPUs und ROCm-Support geteilt
Erfahrungen mit Experimenten zur Inferenz des 405B-Modells geteilt
torch.cudasei ihrer Meinung nach gar nicht so schlechtrocm:pytorch-Containers sei genauso einfach wie die desrocm:jax-ContainersFrage nach dem Fehlen von Performance-Daten
Frage, warum Obsidian (Notiz-App) das tut
Bitte an @dang, den Benutzernamen in die URL aufzunehmen