2 Punkte von GN⁺ 2024-09-24 | 1 Kommentare | Auf WhatsApp teilen

Felafax BlogTune Llama3 405B auf AMD MI300x (unsere Reise)

Einführung

  • Mit dem Wachstum von Open-Source-Modellen steigt der Bedarf an leistungsfähiger Infrastruktur für das Training großer KI-Modelle
  • Felafax hat das Modell LLaMA 3.1 405B auf AMD-GPUs feinabgestimmt und damit die Effizienz von AMD-Hardware demonstriert
  • Die gesamte Arbeit wurde als Open Source auf GitHub veröffentlicht
  • AMD-MI300X-GPUs bieten im Vergleich zu NVIDIA-AI-Hardware eine hohe Leistung
  • Das Projekt wurde durch die Unterstützung von TensorWave ermöglicht

Was ist JAX und warum wurde es gewählt?

  • JAX ist eine leistungsstarke Machine-Learning-Bibliothek, die eine NumPy-ähnliche API, automatische Differenzierung und Googles XLA-Compiler kombiniert
  • Es bietet hervorragende APIs für Modellparallelisierung und ist daher ideal für das Training großer Modelle

Vorteile von JAX

  • Reine Funktionen: JAX fördert das Schreiben reiner Funktionen, wodurch sich Code leichter strukturieren, debuggen und lesen lässt
  • Fortgeschrittene Parallelisierung: Die flexible JIT-API von JAX unterstützt fortgeschrittene Daten- und Modellparallelisierung, die für groß angelegtes Training essenziell ist
  • Saubere Codebasis: Die Designphilosophie von JAX fördert das Schreiben von Code, der zwischen Hardwareplattformen portierbar ist

Warum JAX auf nicht von NVIDIA stammender Hardware herausragt

  • Hardwareunabhängiger Ansatz: JAX nutzt den XLA-Compiler, um Berechnungen in eine hardwareunabhängige Zwischenrepräsentation zu kompilieren
  • Plattformunabhängige Optimierung: Der XLA-Compiler führt Optimierungen unabhängig von der Hardware durch
  • Einfache Portierbarkeit: Mit JAX sind beim Wechsel von NVIDIA zu AMD nur minimale Codeänderungen nötig

JAX auf AMD-GPUs einrichten

  • Ein Docker-Image wurde geladen, der Container gestartet und anschließend die Installation überprüft
  • Das Modell LLaMA 405B wurde mit 8 AMD-MI300X-GPUs trainiert

Training von LLaMA 405B: Leistung und Skalierbarkeit

  • Das Modell LLaMA 405B wurde mit JAX auf AMD-GPUs trainiert
  • Beim LoRA-Feintuning wurden Modellgewichte und LoRA-Parameter mit bfloat16-Präzision angepasst
  • Modellgröße: belegt etwa 800 GB VRAM
  • LoRA-Gewichte und Optimizer-Zustände: belegen etwa 400 GB VRAM
  • Gesamte VRAM-Nutzung: etwa 1200 GB
  • Trainingsgeschwindigkeit: etwa 35 Token pro Sekunde
  • Speichereffizienz: hält etwa 70 %
  • Skalierbarkeit: mit JAX nahezu linear über 8 GPUs skalierbar

Unser Trainings-Setup

  • LLaMA 3.1 wurde von PyTorch nach JAX konvertiert
  • Das Modell wurde durch Modellladen und Parameter-Sharding effizient verteilt

Parameter-Sharding in JAX

  • Mithilfe der Device-Mesh-Funktion von JAX wurde das Modell effizient auf 8 AMD-GPUs verteilt
  • Durch die Definition von Parameter-Sharding-Regeln wurden die Dimensionen jedes Tensors entlang der Mesh-Achsen geshardet

Implementierung des LoRA-Trainings

  • LoRA reduziert die Zahl trainierbarer Parameter, indem Gewichtsaktualisierungen in niedrig-rangige Matrizen zerlegt werden
  • Es wurde eine LoRADense-Schicht implementiert, die LoRA-Parameter enthält
  • LoRA-Parameter wurden effizient verteilt, um Speichernutzung und Recheneffizienz zu optimieren

Fazit

  • Die Erfahrung mit der Feinabstimmung des Modells LLaMA 3.1 405B auf AMD-GPUs mit JAX war sehr positiv
  • Mithilfe der leistungsstarken Parallelisierungsfunktionen von JAX und seines hardwareunabhängigen Ansatzes wurde das Modell effizient verteilt
  • Es wurde gezeigt, dass AMD-GPUs eine starke Alternative für das Training großer KI-Modelle sind
  • Der vollständige Code kann im GitHub-Repository eingesehen und direkt ausgeführt werden

Zusammenfassung von GN⁺

  • Dieser Artikel erklärt, wie sich große KI-Modelle mit AMD-GPUs und JAX effizient trainieren lassen
  • Es wird hervorgehoben, dass AMD-Hardware im Vergleich zu NVIDIA eine kosteneffiziente Alternative ist
  • Der hardwareunabhängige Ansatz von JAX erhöht die Portierbarkeit des Codes und erleichtert die Wartung
  • Der Artikel bietet nützliche Informationen und Praxiscode für alle, die sich für das Training großer Modelle interessieren
  • Projekte mit ähnlichen Funktionen sind unter anderem NVIDIA CUDA und PyTorch

1 Kommentare

 
GN⁺ 2024-09-24
Hacker-News-Kommentare
  • Kürzlich wurde das Modell llama3.1 405B auf 8x AMD MI300x GPUs mit JAX statt PyTorch feinabgestimmt.
    Dank der fortgeschrittenen Sharding-APIs von JAX war die Performance gut; die verwendeten Sharding-Techniken sind im Blog zusammengefasst. Der Code ist ebenfalls veröffentlicht: https://github.com/felafax/felafax
    Wir sind ein kleines Startup, das AI-Infrastruktur für Fine-Tuning und Serving von LLMs auf Nicht-NVIDIA-Hardware (TPU, AMD, Trainium) baut.
    Viele Unternehmen versuchen, PyTorch auf AMD GPUs laufen zu lassen, aber PyTorch ist mit Dingen wie torch.cuda oder scaled_dot_product_attention tief mit dem NVIDIA-Ökosystem verflochten, sodass unserer Ansicht nach viel „Ent-NVIDIA-isierung“ nötig ist.
    Wir denken, dass JAX besser zu Nicht-NVIDIA-Hardware passt, weil Modellcode in hardwareunabhängige HLO-Graphen kompiliert wird, anschließend der XLA-Compiler optimiert und danach hardwarespezifische Optimierungen angewendet werden. Derselbe LLaMA3-JAX-Code lief unverändert auf Google TPU und AMD GPU.
    Die Strategie des Unternehmens ist, Modelle zunächst nach JAX zu portieren und dann das JAX-Framework sowie XLA-Kernel zu nutzen, um auf Nicht-NVIDIA-Backends maximale Performance herauszuholen. Deshalb haben wir Llama 3.1 zuerst von PyTorch nach JAX übertragen, und dasselbe JAX-Modell läuft gut auf TPU und AMD GPU.

    • Es gab keine größeren Probleme, PyTorch auf AMD GPUs ohne Änderungen am CUDA-Code laufen zu lassen. Auch der Blog von MosaicML ist einen Blick wert: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • Mich würde interessieren, wie ihr die Genauigkeit der JAX-Portierung von Llama 3.1 validiert.
      Persönlich ist der Hauptgrund, warum ich PyTorch verwende, dass das ursprüngliche Modell in PyTorch erstellt wurde. Auch wenn die Logik zwischen verschiedenen Modellversionen gleich aussieht, können sich bei enormen Datenmengen winzige Floating-Point-Abweichungen aufsummieren und zu Model Drift führen.
      Solche Genauigkeitsabweichungen bei großen Modellen zu debuggen, kam mir schlimmer vor als der zehnte Kreis der Hölle.
    • Mich würde interessieren, ob JAX eigene Implementierungen für Matrixmultiplikation oder FlashAttention hat oder ob es wie PyTorch ROCm-Implementierungen nutzt, etwa hipblaslt oder Composable Kernel FA.
      Ich kenne JAX nicht besonders gut, aber ein erheblicher Teil der miserablen PyTorch-Trainingsperformance auf MI300x scheint mir daran zu liegen, dass die intern verwendeten ROCm-Bibliotheken langsam sind.
    • Mich würde interessieren, ob es auch auf Consumer-Karten wie der 7900 XTX funktioniert.
      Mit funktionieren meine ich hier nicht, dass man zwei Wochen damit verbringt, die Treiber zum Laufen zu bekommen, und den Server danach nie wieder aktualisieren kann.
    • Wenn es um eine Migration geht: Gibt es konkrete Zahlen im Vergleich zur PyTorch-Version desselben Modells? Die Vergleichstabelle im Artikel wirkt eher technisch.
      Mich würden auch die technischen Probleme interessieren, auf die ihr gestoßen seid.
  • Um es klar zu sagen: Diese Performance ist ziemlich schlecht. Vermutlich liegt das daran, dass Kompilierung nicht richtig zum Laufen gebracht wurde.
    Beim 405B-Modell werden 35 Token/s erreicht, was etwa 85 Teraflops entspricht. 8 MI300x GPUs liegen bei rund 10,4 Petaflops, also beträgt die MFU etwa 0,8 %.
    Das ist 40- bis 50-mal niedriger als ordentliche Trainingsperformance von 30–40 % MFU; aus AMDs Sicht wäre wohl zu hoffen, dass der Software-Stack der Flaschenhals ist.

    • Genau das wollte ich auch fragen.
      Auf der GitHub-Seite heißt es, man könne LLaMa3.1 auf Google Cloud TPU zu 30 % niedrigeren Kosten tunen, aber Performance wird nicht erwähnt.
  • Großartige Arbeit. Vor etwa einem Jahr habe ich ein wenig mit AMD GPUs und ROCm-Support herumgespielt, und es war klar, dass AMD noch einen langen Weg vor sich hat, um zu Nvidia aufzuschließen.
    Der Ansatz mit JAX ist interessant; mich würde interessieren, welche Schwierigkeiten es gab, sich von PyTorch zu lösen, das ja fast schon die Standardbibliothek für Machine Learning ist.

    • Vor ein paar Wochen haben wir ein Show HN zu unserer Reise gepostet: https://news.ycombinator.com/item?id=41512142
      Anfangs war das Ziel, LLaMA 3 auf TPU feinabzustimmen, aber PyTorch XLA war sperrig, daher haben wir beschlossen, das Modell in JAX neu zu schreiben.
      Wie gesagt sehen wir JAX als bessere Plattform für Nicht-NVIDIA-GPUs und möchten auf JAX+openXLA Infrastruktur für Nicht-NVIDIA-GPUs aufbauen.
    • Auf meinem Debian 12-System bekomme ich AMD ROCm nicht zum Laufen, und deshalb scheint Ollama statt der GPU die CPU zu verwenden. Da scheint noch ein langer Weg vor uns zu liegen.
  • Gute Arbeit. Am vergangenen Wochenende habe ich selbst an der Inference-Seite von 405B herumprobiert [0].
    Ich bin nicht überzeugt, dass torch.cuda so schlimm ist, denn PyTorch für AMD übersetzt das entsprechend. Das ist eher ein Namensproblem als ein grundsätzliches.
    Tatsächlich ist es genauso einfach, einen rocm:pytorch-Container zu ziehen wie einen rocm:jax-Container.
    Es wurden nicht viele Zahlen veröffentlicht; mich würde interessieren, welche MFU erreicht wurde.
    [0] https://x.com/HotAisle/status/1837580046732874026

    • Gut.
      Die MFU müssen wir berechnen. Details zu GPU und VRAM sind im Repository zu sehen: https://dub.sh/amd-405b-res
      Am nächsten Wochenende planen wir, den Trainingslauf erneut zu versuchen, dabei den gesamten Trainingsschritt per JIT zu kompilieren und dann die MFU zu berechnen.
  • Bei unseren Messungen bei ZML war MI300X 30 % schneller als H100. Das sind hervorragende Chips.

  • Mich würde interessieren, ob es einen Cloud-Anbieter gibt, bei dem man 8xAMD-MI300-Hosts mieten kann.
    Beruflich nutze ich viel AWS, wollte aber einmal AMD GPUs ausprobieren.

    • Zur Info: Unser Unternehmen vermietet 8xMI300x, du kannst dich also gern melden.
    • Oracle bietet das an. Andere werden wahrscheinlich nachziehen, aber kleinere Anbieter dürften vernünftiger im Umgang sein.
  • Wo sind die Performance-Daten?

    • Ich habe Daten zur GPU- und VRAM-Auslastung ins GitHub-Repository aufgenommen: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      Wegen Code- und VRAM-Beschränkungen konnten wir die JIT-kompilierte Version des 405B-Modells nicht ausführen. Das müssen wir weiter untersuchen.
      Der vollständige Trainingslauf wurde im JAX-Eager-Modus durchgeführt, daher gibt es viel Raum für Performance-Verbesserungen.
      Selbst im Eager-Modus lag die GPU-Auslastung insgesamt bei etwa 30–40 %, was ziemlich ordentlich ist. Mit JIT dürfte sich die GPU-Auslastung unserer Ansicht nach leicht auf 50–60 % erhöhen lassen.
  • Wenn möglich, wäre es interessant zu untersuchen, wie man die Speicherbeschränkungen überwinden und die JIT-kompilierte Version ausführen kann. Das könnte weitere Performance-Verbesserungen bringen.

    • Stimme zu. Da ist noch viel Performance herauszuholen.
      Wir brauchen einen JIT-kompilierten Trainingsschritt, stärker optimiertes Data Loading und Sharding, Gradientenakkumulation sowie Activation Checkpointing.
      Wir bauen weiter, implementieren alle Verbesserungen und werden bald wieder einen Blogpost veröffentlichen.
  • Ich frage mich, ob AMD auch nur ein Stück näher daran ist, hier durch GPU-Großaufträge und Lieferengpässe Wert abzuschöpfen.
    Mein Eindruck geht eher in Richtung „nein“.

    • Ich verstehe den Sarkasmus. Aber wenn man zum jetzigen Zeitpunkt nicht die gesamte Hardware und Software für AI einem einzigen Anbieter überlassen will, muss man anfangen, sich auf Alternativen zuzubewegen.
      Der Gegner hat einen enormen Vorsprung, und auf der Softwareseite gibt es eindeutig viel zu tun. Das braucht Zeit.
  • Warum macht Obsidian, eine Notiz-App, so etwas?

    • Tut sie nicht. Dieses Unternehmen nutzt Obsidian Publish für die Veröffentlichung seiner Dokumentation.