PyTorch ist tot. Es lebe JAX
(neel04.github.io)- Der Grund, warum PyTorch zu Produktivitätsverlusten und verschwendeter Entwicklungszeit führt, ist „nicht, dass das Framework selbst schlecht ist, sondern dass es nicht für die Use Cases entworfen wurde, auf die es heute angewendet wird“
Die Philosophie von PyTorch
- Die Philosophie von PyTorch ist dynamisch, leicht zu debuggen und pythonisch
- TensorFlow 1.x dagegen wollte mit starker Nutzung des XLA-Compilers ein statisches, aber performantes Framework sein
- Die TensorFlow-Entwickler erkannten, dass die Community die 1.x-API nicht mochte, entschieden sich für Keras als Hauptschnittstelle und reduzierten die Rolle des XLA-Compilers
- PyTorch blieb seinen Wurzeln treu und setzte im Gegensatz zum statischen und verzögerten Ansatz von TensorFlow auf einen dynamischeren „Eager Execution“-Ansatz, bei dem
torch.Tensorsofort ausgewertet wird - Das war erfolgreich, und viel Forschung wechselte zu PyTorch
- Mit dem Aufkommen von GPT-3 im Jahr 2021 wurden Performance und Skalierbarkeit zu zentralen Themen
- PyTorch reagierte auf diese Nachfrage bis zu einem gewissen Grad gut, aber da es nicht mit dieser Philosophie im Hinterkopf entworfen wurde, häuften sich zunehmend Altlasten an und das Fundament begann zu wanken
- Die PyTorch-Entwickler wollten keinerlei Kompromisse und entschieden sich, zwei Wege gleichzeitig zu verfolgen
- Einsatz des XLA-Compilers als Standard-Backend mit hoher Performance und Stabilität
- Aufbau des
torch.compile-Stacks, damit Nutzer bei Bedarf den Compiler aufrufen können
- Das Fehlen einer langfristigen Strategie ist ein ernstes Problem
- PyTorch möchte sich nicht auf eine compilerzentrierte Philosophie wie JAX festlegen, aber eine gute Alternative ist nicht in Sicht
- Wie lösen Konkurrenzprodukte dieses Problem?
Compilerbasierte Entwicklung in JAX
- JAX nutzt XLA, den leistungsfähigen Compiler-Stack von TensorFlow
- XLA ist ein mächtiger Compiler, für Endnutzer aber vollständig abstrahiert
- Solange eine Funktion rein (pure) ist, kann man sie mit dem
@jax.jit-Decorator JIT-kompilieren und für XLA nutzbar machen - XLA übernimmt im Hintergrund die Prüfung, ob der erzeugte Graph korrekt ist, den GSPMD-Partitionierer für automatische Parallelisierung mit Sharding in JAX, Graph-Optimierung, Operator- und Kernel-Fusion, Latency-Hiding-Scheduling, asynchrone Kommunikationsüberlappung, Codegenerierung für andere Backends wie triton und mehr
- Solange man die Einschränkungen von JAX einhält, erledigt XLA alles automatisch
- Beim Parallelisieren braucht man zum Beispiel keine Kommunikationsprimitive wie
torch.distributed.barrier() - DDP-Unterstützung ist mit einfachem Code möglich
- Der Ansatz von XLA ist, dass die Berechnung dem Sharding folgt. Wenn also ein Eingabearray entlang einer bestimmten Achse geshardet ist, übernimmt XLA automatisch die nachgelagerten Berechnungen
- Die Idee der „compilerbasierten Entwicklung“ ähnelt der Arbeitsweise des Rust-Compilers
- Die Grenzen von PyTorch
- Unzufriedenheit darüber, dass sich die PyTorch-Entwickler entschieden haben, für neue Features den Compiler-Stack zu integrieren und sich auf ihn zu stützen, statt die Kernphilosophie von Flexibilität und Freiheit beizubehalten
- Laut offizieller Roadmap von PyTorch 2.x gibt es klar einen langfristigen Plan, XLA vollständig in Torch zu integrieren
- Das ist eine schreckliche Idee. Es ist, als würde man behaupten, in den Rust-Compiler hineingezwungener C++-Code biete eine bessere Erfahrung, als Rust direkt zu verwenden
- Torch wurde im Gegensatz zu JAX nicht um XLA herum entworfen
- Wenn PyTorch sich für einen XLA-basierten Compiler-Stack entscheidet, wäre dann nicht ein ideales Framework eines, das speziell darum herum entworfen und gebaut wurde?
- Selbst wenn PyTorch einen „Multi-Backend“-Ansatz verfolgt, bei dem man das gewünschte Compiler-Backend wählen kann: Würde das nicht das Fragmentierungsproblem verschärfen und die API völlig ruinieren, weil man versucht, die Einschränkungen aller Compiler-Stacks zu respektieren?
- Jeder, der Torch/XLA auf TPUs benutzt hat, leidet unter schwerem PTSD
Multi-Backend ist gescheitert
- PyTorch wollte alles gleichzeitig machen und ist dabei kläglich gescheitert
- Die Designentscheidung „Multi-Backend“ verschlimmert dieses Problem exponentiell
- In der Theorie klingt es so, als könne man einfach den gewünschten Stack wählen, in der Praxis ist es jedoch ein verworrenes Chaos aus schwer verständlichen Tracebacks und Inkompatibilitätsproblemen
- Einschränkungen zwischen Backends und Konflikte mit der PyTorch-API
- Es ist nicht nur schwierig, diese Backends überhaupt zum Laufen zu bringen; die Einschränkungen, die sie erwarten, passen auch schlecht zur flexiblen, pythonischen API von PyTorch
- Es gibt einen Trade-off zwischen der Wahrung von API-Konsistenz und dem Einhalten der Backend-Beschränkungen
- Infolgedessen verlassen sich Entwickler stärker auf Codegenerierung, statt sich wirklich mit einem einzelnen Backend zu integrieren bzw. darauf festzulegen
- Das Fehlen einer Strategie bei PyTorch
- Weil PyTorch sinnvolle Trade-offs ablehnt, fühlt sich jede Entscheidung wie ein Kompromiss an
- Es gibt weder Konsistenz noch eine übergreifende Strategie
- Letztlich sorgt das für viel Frustration bei den Nutzern und wirkt wie ein Sammelsurium von Features, die nicht zusammenpassen
- Einen schnelleren Weg, ein Ökosystem zu töten, gibt es kaum
- Warum man dem JAX-Ansatz nicht folgen sollte
- PyTorch sollte nicht den JAX-Ansatz eines „integrierten Compilers und Backends“ übernehmen
- Denn JAX wurde explizit dafür entworfen, mit XLA zu arbeiten
- Es kann keine Strategie sein, das PyTorch-Frontend durch das von JAX zu ersetzen
- Es ist praktisch unmöglich, auf Basis von XLA eine bessere API als JAX zu entwerfen
- Den Entwicklern wird nicht vorgeworfen, neue und andere Ideen auszuprobieren
- Aber wenn PyTorch die Zeit überdauern will, sollte es sich stärker darauf konzentrieren, das Fundament zu stärken, statt schicke neue Features zu liefern, die außerhalb idealer Tutorial-Bedingungen sofort zusammenbrechen
Die Fragmentierung von PyTorch und die funktionale Programmierung von JAX
- Die funktionale API von JAX
- JAX-Funktionen müssen rein (pure) sein, also keine globalen Seiteneffekte haben
- Wie mathematische Funktionen müssen sie bei denselben Daten immer dieselbe Ausgabe liefern, unabhängig vom Ausführungskontext
- Dank dieser Designphilosophie sind JAX-Funktionen gut komponierbar und interoperieren gut miteinander
- Die Entwicklungskomplexität sinkt, und Funktionen sind als konkrete Operationen mit bestimmter Signatur klar definiert
- Wenn die Typen stimmen, ist garantiert, dass die Funktion sofort funktioniert
- Das passt gut zu wissenschaftlichem Rechnen, insbesondere zu den in Deep Learning benötigten Arbeitslasten
- Beispiel für die optax-API
- Dank des funktionalen Ansatzes gibt es in optax etwas namens „chain“
- Darin sind mehrere Funktionen enthalten, die nacheinander auf Gradienten angewendet werden
- Die grundlegende Komponente ist
GradientTransformation - Das ergibt eine leistungsfähige und zugleich ausdrucksstarke API
- So werden Dinge wie Gradient Clipping, das Bilden einer EMA der Gradienten oder das Kombinieren von Optimizern sehr einfach
- Vorteile des funktionalen Designs
- Ein weiteres großartiges Ergebnis des funktionalen Designs ist
vmap - Das steht für „vectorized map“ und beschreibt genau, was es tut
- Man kann alles mappen, und solange es
vmapist, fusioniert und optimiert XLA automatisch - Beim Schreiben von Funktionen muss man nicht an Batch-Dimensionen denken
- Man muss nur den gesamten Code
vmapen - Das bedeutet, dass weniger ein-* Operationen nötig sind
- Das Verarbeiten von 2D-/3D-Tensoren ist intuitiver und deutlich lesbarer
- Weil man nur einzelne Komponenten isolieren und darüber nachdenken muss, lässt sich komplexer Code, der gut funktioniert, leichter schreiben
- Solange man die Reinheitsanforderungen respektiert und die richtige Signatur hat, erhält man alle weiteren Vorteile wie Komponierbarkeit
- Ein weiteres großartiges Ergebnis des funktionalen Designs ist
- Probleme des PyTorch-Ökosystems
- In
torchkann immer etwas kaputtgehen, unabhängig davon, welchen Stack man nutzt (FSDP+ Multi-Node +torch.compileusw.) - Viele Dinge müssen korrekt zusammenspielen, und wenn eine einzige Komponente versagt, debuggt man bis 3 Uhr morgens
- Weil man nicht alle Kombinationen der Dutzenden von Features testen kann, die PyTorch bietet, wird es immer Bugs geben, die während der Entwicklung nicht entdeckt wurden
- Ohne erheblichen Aufwand ist es unmöglich, Code zu schreiben, der zuverlässig gut funktioniert
- Das
torch-Ökosystem ist extrem aufgebläht und fehleranfällig geworden - Weil es keine gemeinsame Abstraktion gibt, entstehen neue Bibliotheken und Frameworks, die nicht dafür entworfen wurden, mit anderen „Lösungen“ zu interagieren
- Das verkommt schnell zu einem Chaos aus Abhängigkeiten und
requirements.txt - 70–80 % der GitHub-Issues oder Forendiskussionen entstehen schlicht dadurch, dass verschiedene Bibliotheken Fehler werfen
- Es gibt kaum eine Möglichkeit, das zu beheben
- In
- Das Fehlen einer Lösung
- Das ist ein Problem von OOP und Design
- Grundlegende, PyTorch-artige Objekte wie
PyTreehätten vermutlich geholfen, eine gemeinsame Basis für Abstraktion zu schaffen - Auch ein Wechsel zum Paradigma der funktionalen Programmierung ist nicht möglich
- Das würde zu einer schlechter performenden Version von JAX führen und gleichzeitig die Abwärtskompatibilität zu allen bestehenden torch-Codebasen brechen
- PyTorch wirkt an diesem Punkt völlig kaputt
Der Vorteil von JAX bei der Reproduzierbarkeit
- Umgang mit Seeds
- Der Umgang mit Seeds in PyTorch ist nicht ideal
- In der Regel muss man mehrere Codezeilen ausführen
- Das vergisst man leicht oder konfiguriert es falsch
- JAX erzwingt, dass man explizite Keys erzeugt und sie an jede Funktion übergibt, die Zufälligkeit benötigt
- Dieser Ansatz beseitigt das Problem vollständig, weil der RNG immer statisch geseedet ist
- JAX hat seine eigene Version von NumPy (
jax.numpy), daher muss man den Seed nicht separat setzen - Solche kleinen QoL-Entscheidungen können die User Experience des gesamten Frameworks deutlich verbessern
- Portabilität
- Eines der größten Probleme bei PyTorch-Codebasen ist ihre mangelnde Portabilität
- Codebasen, die für CUDA/GPU geschrieben wurden, funktionieren auf nicht-Nvidia-Hardware wie TPU, NPU oder AMD-GPU oft nicht gut
- PyTorch-Code für einen einzelnen Node lässt sich nur schwer auf Multi-Node portieren
- Multi-Node erfordert oft Dutzende Stunden Entwicklungszeit und erhebliche Codeänderungen
- Der compilerzentrierte Ansatz von JAX hat hier Vorteile
- XLA übernimmt den Wechsel zwischen Geräte-Backends und funktioniert mit minimalen Codeänderungen gut auf GPU/TPU/Multi-Node/Multi-Slice
- Das erleichtert Hardwareanbietern die Unterstützung ihrer Geräte und macht den Wechsel zwischen Geräten einfacher
- Da nicht jeder Zugriff auf dieselbe Hardware hat, könnten portable Codebasen für verschiedene Hardwaretypen ein kleiner Schritt sein, um Deep Learning für Einsteiger und Fortgeschrittene zugänglicher zu machen
- Automatische Skalierung
- Eine Codebasis, die aus sich heraus gut automatisch skaliert, hilft sehr bei der Reproduzierbarkeit
- Idealerweise sollte das mit minimalen Codeänderungen und unabhängig von Netzwerkgrenzen automatisch geschehen
- JAX macht das gut
- Beim Schreiben von JAX-Code muss man keine Kommunikationsprimitive festlegen oder überall
torch.distributed.barrier()einfügen - XLA fügt dies automatisch unter Berücksichtigung der verfügbaren Hardware ein
- Alle Geräte, die JAX erkennen kann, werden unabhängig von Networking, Topologie und Konfiguration automatisch genutzt
- Die Berechnung wird automatisch synchronisiert und vorbereitet, und es werden Optimierungspässe angewendet, um die asynchrone Ausführung von Kernels zu maximieren und Latenzen zu minimieren
- Das Einzige, was der Mensch tun muss, ist das Sharding der Tensoren anzugeben, die auf Geräte verteilt werden sollen, etwa die Batch-Dimension von Eingabearrays
- Wegen des XLA-Ansatzes „Berechnung folgt dem Sharding“ wird der Rest automatisch ermittelt
- Dadurch lassen sich im Hobbybereich leicht validierte Experimente in größerem Maßstab ausführen, um zu experimentieren und sie potenziell zu wiederholen
- Das könnte die Wiederentdeckung vergessener Ideen erleichtern und solche Experimente fördern, weil man sie mit minimalem Aufwand als Funktion in größerem Maßstab testen kann
Nachteile von JAX
- Governance-Struktur
- XLA steht derzeit unter der Governance von TensorFlow
- Es gab Diskussionen über die Einrichtung eines separaten Gremiums ähnlich wie bei PyTorch, aber konkrete Anstrengungen blieben bisher weitgehend aus
- Wegen Googles Ruf, unpopuläre Produkte einzustellen, ist das Vertrauen in Google nicht besonders hoch
- JAX ist technisch gesehen ein DeepMind-Projekt und hat zentrale Bedeutung für Googles gesamten AI-Vorstoß, aber eine langfristig größere Unabhängigkeit wäre für das gesamte Ökosystem von großem Vorteil
- Ein separates Governance-Gremium würde der Projektentwicklung Orientierung geben
- Das würde eine konkrete Struktur schaffen und könnte viele Probleme auf einmal vermeiden, indem es von Googles berüchtigter Bürokratie entkoppelt wird
- JAX braucht nicht zwingend eine solche formale Struktur, aber eine Zusicherung, dass die Entwicklung von JAX langfristig weitergeht – unabhängig von Entscheidungen des Google-Managements –, wäre wünschenswert
- Das würde die Akzeptanz bei Unternehmen und großen Forschungslaboren klar fördern, die zögern, Ressourcen in ein Tool zu investieren, das eines Tages womöglich nicht mehr gepflegt wird
- Die Open-Source-Transformation von XLA
- XLA war lange Zeit ein Closed-Source-Projekt
- Es wurden jedoch Anstrengungen unternommen, es Open Source zu machen, und inzwischen zeigt OpenXLA deutlich bessere Performance als interne XLA-Builds
- Dennoch fehlt es weiterhin an Dokumentation über das Innenleben von XLA
- Die meisten Ressourcen bestehen nur aus Live-Talks und gelegentlichen Papers, die oft veraltet sind
- Eine öffentlich zugängliche Roadmap für geplante Features würde es einfacher machen, Fortschritte zu verfolgen und insbesondere zu interessanten Themen beizutragen
- Mini-Blogposts im Stil von Edward Yang, die jede Phase des XLA-Compiler-Stacks analysieren und Details erklären, wären hilfreich, damit Praktiker besser einschätzen können, was XLA kann und was nicht
- Zwar ist das ressourcenintensiv und lässt sich vielleicht anderswo besser kommunizieren, aber Menschen vertrauen Tools mehr, wenn sie sie verstehen, und das hätte positive Auswirkungen auf das gesamte Ökosystem
- Integration des Ökosystems
- flax ist ein Ärgernis im JAX-Ökosystem
- Es hat eine unintuitive API, eine knappe Syntax und ist für Einsteiger, die von PyTorch wechseln, die reinste Hölle
- Die Empfehlung ist, equinox zu verwenden
- Es gab Versuche des Entwicklerteams, die Schwächen von flax zu beheben, aber letztlich ist das Zeitverschwendung
- Wenn man eine API im Stil von equinox will, sollte man einfach equinox verwenden
- Es gibt nicht viel, was flax besonders besser kann, und es ist nicht schwer, das mit equinox nachzubilden
- Derzeit ist ein großer Teil des JAX-Ökosystems auf flax ausgerichtet
- equinox interoperiert mit allen Bibliotheken, weil es grundlegend mit
PyTreearbeitet, auch wenn dafür etwaseqx.partitionundfilternötig ist - Der Status quo sollte sich ändern. equinox sollte überall erstklassig unterstützt werden
- Das ist eine kontroverse Meinung, aber letztlich ein klassischer Fall der Sunk-Cost-Fallacy
- equinox funktioniert besser so, wie ein JAX-Framework immer hätte funktionieren sollen
- Vergleicht man equinox und flax wie in der equinox-Dokumentation zusammengefasst, ist equinox besser
- Es ist gut, dass die Maintainer des JAX-Ökosystems die Popularität von equinox erkennen und entsprechend reagieren, aber auch von Google und dem flax-Team wäre offiziell mehr Unterstützung wünschenswert
- Wer JAX ausprobieren möchte, sollte equinox verwenden
- Raue Kanten
- Aufgrund von API-Designentscheidungen und XLA-Einschränkungen hat JAX einige „raue Kanten“, auf die man achten muss
- In der gut geschriebenen Dokumentation wird das sehr prägnant erklärt
- Es empfiehlt sich, das vor der Nutzung von JAX mindestens einmal zu lesen
- Wie immer spart
RTFMviel Zeit und Energie
Fazit
- Dieser Blogpost sollte mit dem oft wiederholten Mythos aufräumen, PyTorch sei für reale Forschungs-Workloads, insbesondere auf GPUs, am besten geeignet. Das ist nicht mehr der Fall
- Tatsächlich geht der Autor so weit zu behaupten, dass es für das gesamte Feld enorm vorteilhaft wäre, sämtlichen PyTorch-Code nach JAX zu portieren
- Automatische Parallelisierung, Reproduzierbarkeit und eine saubere funktionale API sind keine Nebensächlichkeiten und würden vielen Forschungs-Codebasen sehr helfen
- Wenn man dieses Feld auch nur ein wenig verbessern will, sollte man in Betracht ziehen, die eigene Codebasis in JAX neu zu schreiben
8 Kommentare
Die Welt dreht sich weiter. Haha.
Vergleich von PyTorch und TensorFlow im Jahr 2022
Ich halte mich mit torch und onnx über Wasser.
Von einem Bachelorstudenten geschriebener Artikel … krass.
Ohne Hugging Face wäre PyTorch echt tot, lol.
Es lebe JAX! Ich habe es kürzlich ausprobiert, und die NNX-API hat mir sehr gut gefallen.
Das größte Problem von JAX ist, dass es von Google stammt. Google ist ziemlich berüchtigt dafür, Open-Source-Projekte fallen zu lassen (
Tflite,android things,dart,angular,bazelusw.), und auch bei TensorFlow begannen die Updates irgendwann zu stocken. Torch hingegen stammt von Facebook, das ein riesiges Open-Source-Ökosystem betreibt, wurde sehr gut gepflegt und steht bereits unter der Verwaltung der Torch Foundation. Die Schwächen von Torch treffen in Teilen sicherlich zu, aber bei der Frage, wer ein solches Open-Source-Projekt nachhaltig betreiben kann, scheint JAX schon von Anfang an ein großes Risiko mitzubringen.Zumindest scheint Dart dank Flutter noch eine Weile gut leben zu können.
Facebook scheint seinen eigenen Technologie-Stack, den sie selbst nutzen – etwa React, Django usw. – immerhin loyal(?) kontinuierlich weiter zu unterstützen und dazu beizutragen, aber bei Google wirkt es so, als würden sie etwas, sobald es auch nur ein bisschen veraltet ist, wie einen alten Lappen wegwerfen...