Feinabstimmung von Llama 405B auf AMD-GPUs
(publish.obsidian.md)Felafax BlogTune Llama3 405B auf AMD MI300x (unsere Reise)
Einführung
- Mit dem Wachstum von Open-Source-Modellen steigt der Bedarf an leistungsfähiger Infrastruktur für das Training großer KI-Modelle
- Felafax hat das Modell LLaMA 3.1 405B auf AMD-GPUs feinabgestimmt und damit die Effizienz von AMD-Hardware demonstriert
- Die gesamte Arbeit wurde als Open Source auf GitHub veröffentlicht
- AMD-MI300X-GPUs bieten im Vergleich zu NVIDIA-AI-Hardware eine hohe Leistung
- Das Projekt wurde durch die Unterstützung von TensorWave ermöglicht
Was ist JAX und warum wurde es gewählt?
- JAX ist eine leistungsstarke Machine-Learning-Bibliothek, die eine NumPy-ähnliche API, automatische Differenzierung und Googles XLA-Compiler kombiniert
- Es bietet hervorragende APIs für Modellparallelisierung und ist daher ideal für das Training großer Modelle
Vorteile von JAX
- Reine Funktionen: JAX fördert das Schreiben reiner Funktionen, wodurch sich Code leichter strukturieren, debuggen und lesen lässt
- Fortgeschrittene Parallelisierung: Die flexible JIT-API von JAX unterstützt fortgeschrittene Daten- und Modellparallelisierung, die für groß angelegtes Training essenziell ist
- Saubere Codebasis: Die Designphilosophie von JAX fördert das Schreiben von Code, der zwischen Hardwareplattformen portierbar ist
Warum JAX auf nicht von NVIDIA stammender Hardware herausragt
- Hardwareunabhängiger Ansatz: JAX nutzt den XLA-Compiler, um Berechnungen in eine hardwareunabhängige Zwischenrepräsentation zu kompilieren
- Plattformunabhängige Optimierung: Der XLA-Compiler führt Optimierungen unabhängig von der Hardware durch
- Einfache Portierbarkeit: Mit JAX sind beim Wechsel von NVIDIA zu AMD nur minimale Codeänderungen nötig
JAX auf AMD-GPUs einrichten
- Ein Docker-Image wurde geladen, der Container gestartet und anschließend die Installation überprüft
- Das Modell LLaMA 405B wurde mit 8 AMD-MI300X-GPUs trainiert
Training von LLaMA 405B: Leistung und Skalierbarkeit
- Das Modell LLaMA 405B wurde mit JAX auf AMD-GPUs trainiert
- Beim LoRA-Feintuning wurden Modellgewichte und LoRA-Parameter mit bfloat16-Präzision angepasst
- Modellgröße: belegt etwa 800 GB VRAM
- LoRA-Gewichte und Optimizer-Zustände: belegen etwa 400 GB VRAM
- Gesamte VRAM-Nutzung: etwa 1200 GB
- Trainingsgeschwindigkeit: etwa 35 Token pro Sekunde
- Speichereffizienz: hält etwa 70 %
- Skalierbarkeit: mit JAX nahezu linear über 8 GPUs skalierbar
Unser Trainings-Setup
- LLaMA 3.1 wurde von PyTorch nach JAX konvertiert
- Das Modell wurde durch Modellladen und Parameter-Sharding effizient verteilt
Parameter-Sharding in JAX
- Mithilfe der Device-Mesh-Funktion von JAX wurde das Modell effizient auf 8 AMD-GPUs verteilt
- Durch die Definition von Parameter-Sharding-Regeln wurden die Dimensionen jedes Tensors entlang der Mesh-Achsen geshardet
Implementierung des LoRA-Trainings
- LoRA reduziert die Zahl trainierbarer Parameter, indem Gewichtsaktualisierungen in niedrig-rangige Matrizen zerlegt werden
- Es wurde eine
LoRADense-Schicht implementiert, die LoRA-Parameter enthält - LoRA-Parameter wurden effizient verteilt, um Speichernutzung und Recheneffizienz zu optimieren
Fazit
- Die Erfahrung mit der Feinabstimmung des Modells LLaMA 3.1 405B auf AMD-GPUs mit JAX war sehr positiv
- Mithilfe der leistungsstarken Parallelisierungsfunktionen von JAX und seines hardwareunabhängigen Ansatzes wurde das Modell effizient verteilt
- Es wurde gezeigt, dass AMD-GPUs eine starke Alternative für das Training großer KI-Modelle sind
- Der vollständige Code kann im GitHub-Repository eingesehen und direkt ausgeführt werden
Zusammenfassung von GN⁺
- Dieser Artikel erklärt, wie sich große KI-Modelle mit AMD-GPUs und JAX effizient trainieren lassen
- Es wird hervorgehoben, dass AMD-Hardware im Vergleich zu NVIDIA eine kosteneffiziente Alternative ist
- Der hardwareunabhängige Ansatz von JAX erhöht die Portierbarkeit des Codes und erleichtert die Wartung
- Der Artikel bietet nützliche Informationen und Praxiscode für alle, die sich für das Training großer Modelle interessieren
- Projekte mit ähnlichen Funktionen sind unter anderem NVIDIA CUDA und PyTorch
1 Kommentare
Hacker-News-Kommentare
Kürzlich wurde das Modell llama3.1 405B auf 8x AMD MI300x GPUs mit JAX statt PyTorch feinabgestimmt.
Dank der fortgeschrittenen Sharding-APIs von JAX war die Performance gut; die verwendeten Sharding-Techniken sind im Blog zusammengefasst. Der Code ist ebenfalls veröffentlicht: https://github.com/felafax/felafax
Wir sind ein kleines Startup, das AI-Infrastruktur für Fine-Tuning und Serving von LLMs auf Nicht-NVIDIA-Hardware (TPU, AMD, Trainium) baut.
Viele Unternehmen versuchen, PyTorch auf AMD GPUs laufen zu lassen, aber PyTorch ist mit Dingen wie
torch.cudaoderscaled_dot_product_attentiontief mit dem NVIDIA-Ökosystem verflochten, sodass unserer Ansicht nach viel „Ent-NVIDIA-isierung“ nötig ist.Wir denken, dass JAX besser zu Nicht-NVIDIA-Hardware passt, weil Modellcode in hardwareunabhängige HLO-Graphen kompiliert wird, anschließend der XLA-Compiler optimiert und danach hardwarespezifische Optimierungen angewendet werden. Derselbe LLaMA3-JAX-Code lief unverändert auf Google TPU und AMD GPU.
Die Strategie des Unternehmens ist, Modelle zunächst nach JAX zu portieren und dann das JAX-Framework sowie XLA-Kernel zu nutzen, um auf Nicht-NVIDIA-Backends maximale Performance herauszuholen. Deshalb haben wir Llama 3.1 zuerst von PyTorch nach JAX übertragen, und dasselbe JAX-Modell läuft gut auf TPU und AMD GPU.
Persönlich ist der Hauptgrund, warum ich PyTorch verwende, dass das ursprüngliche Modell in PyTorch erstellt wurde. Auch wenn die Logik zwischen verschiedenen Modellversionen gleich aussieht, können sich bei enormen Datenmengen winzige Floating-Point-Abweichungen aufsummieren und zu Model Drift führen.
Solche Genauigkeitsabweichungen bei großen Modellen zu debuggen, kam mir schlimmer vor als der zehnte Kreis der Hölle.
hipblasltoder Composable Kernel FA.Ich kenne JAX nicht besonders gut, aber ein erheblicher Teil der miserablen PyTorch-Trainingsperformance auf MI300x scheint mir daran zu liegen, dass die intern verwendeten ROCm-Bibliotheken langsam sind.
Mit funktionieren meine ich hier nicht, dass man zwei Wochen damit verbringt, die Treiber zum Laufen zu bekommen, und den Server danach nie wieder aktualisieren kann.
Mich würden auch die technischen Probleme interessieren, auf die ihr gestoßen seid.
Um es klar zu sagen: Diese Performance ist ziemlich schlecht. Vermutlich liegt das daran, dass Kompilierung nicht richtig zum Laufen gebracht wurde.
Beim 405B-Modell werden 35 Token/s erreicht, was etwa 85 Teraflops entspricht. 8 MI300x GPUs liegen bei rund 10,4 Petaflops, also beträgt die MFU etwa 0,8 %.
Das ist 40- bis 50-mal niedriger als ordentliche Trainingsperformance von 30–40 % MFU; aus AMDs Sicht wäre wohl zu hoffen, dass der Software-Stack der Flaschenhals ist.
Auf der GitHub-Seite heißt es, man könne LLaMa3.1 auf Google Cloud TPU zu 30 % niedrigeren Kosten tunen, aber Performance wird nicht erwähnt.
Großartige Arbeit. Vor etwa einem Jahr habe ich ein wenig mit AMD GPUs und ROCm-Support herumgespielt, und es war klar, dass AMD noch einen langen Weg vor sich hat, um zu Nvidia aufzuschließen.
Der Ansatz mit JAX ist interessant; mich würde interessieren, welche Schwierigkeiten es gab, sich von PyTorch zu lösen, das ja fast schon die Standardbibliothek für Machine Learning ist.
Anfangs war das Ziel, LLaMA 3 auf TPU feinabzustimmen, aber PyTorch XLA war sperrig, daher haben wir beschlossen, das Modell in JAX neu zu schreiben.
Wie gesagt sehen wir JAX als bessere Plattform für Nicht-NVIDIA-GPUs und möchten auf JAX+openXLA Infrastruktur für Nicht-NVIDIA-GPUs aufbauen.
Gute Arbeit. Am vergangenen Wochenende habe ich selbst an der Inference-Seite von 405B herumprobiert [0].
Ich bin nicht überzeugt, dass
torch.cudaso schlimm ist, denn PyTorch für AMD übersetzt das entsprechend. Das ist eher ein Namensproblem als ein grundsätzliches.Tatsächlich ist es genauso einfach, einen
rocm:pytorch-Container zu ziehen wie einenrocm:jax-Container.Es wurden nicht viele Zahlen veröffentlicht; mich würde interessieren, welche MFU erreicht wurde.
[0] https://x.com/HotAisle/status/1837580046732874026
Die MFU müssen wir berechnen. Details zu GPU und VRAM sind im Repository zu sehen: https://dub.sh/amd-405b-res
Am nächsten Wochenende planen wir, den Trainingslauf erneut zu versuchen, dabei den gesamten Trainingsschritt per JIT zu kompilieren und dann die MFU zu berechnen.
Bei unseren Messungen bei ZML war MI300X 30 % schneller als H100. Das sind hervorragende Chips.
Mich würde interessieren, ob es einen Cloud-Anbieter gibt, bei dem man 8xAMD-MI300-Hosts mieten kann.
Beruflich nutze ich viel AWS, wollte aber einmal AMD GPUs ausprobieren.
Wo sind die Performance-Daten?
Wegen Code- und VRAM-Beschränkungen konnten wir die JIT-kompilierte Version des 405B-Modells nicht ausführen. Das müssen wir weiter untersuchen.
Der vollständige Trainingslauf wurde im JAX-Eager-Modus durchgeführt, daher gibt es viel Raum für Performance-Verbesserungen.
Selbst im Eager-Modus lag die GPU-Auslastung insgesamt bei etwa 30–40 %, was ziemlich ordentlich ist. Mit JIT dürfte sich die GPU-Auslastung unserer Ansicht nach leicht auf 50–60 % erhöhen lassen.
Wenn möglich, wäre es interessant zu untersuchen, wie man die Speicherbeschränkungen überwinden und die JIT-kompilierte Version ausführen kann. Das könnte weitere Performance-Verbesserungen bringen.
Wir brauchen einen JIT-kompilierten Trainingsschritt, stärker optimiertes Data Loading und Sharding, Gradientenakkumulation sowie Activation Checkpointing.
Wir bauen weiter, implementieren alle Verbesserungen und werden bald wieder einen Blogpost veröffentlichen.
Ich frage mich, ob AMD auch nur ein Stück näher daran ist, hier durch GPU-Großaufträge und Lieferengpässe Wert abzuschöpfen.
Mein Eindruck geht eher in Richtung „nein“.
Der Gegner hat einen enormen Vorsprung, und auf der Softwareseite gibt es eindeutig viel zu tun. Das braucht Zeit.
Warum macht Obsidian, eine Notiz-App, so etwas?