FlashAttention-3: Schnelle und präzise Attention durch Asynchronität und geringe Präzision
(together.ai)-
Die Bedeutung von Attention
- Attention ist die zentrale Schicht der Transformer-Architektur und verursacht Engpässe bei großen Sprachmodellen und Anwendungen mit langem Kontext.
- FlashAttention und FlashAttention-2 haben einen Ansatz etabliert, der Attention beschleunigt, indem Lese- und Schreibzugriffe auf den GPU-Speicher minimiert werden.
- Dadurch konnte die Kontextlänge von LLMs deutlich erhöht werden.
-
Die wichtigsten Techniken von FlashAttention-3
- Nutzung von Asynchronität: Nutzt die Asynchronität von Tensor Cores und TMA, um Berechnungen und Datenbewegungen vollständig zu überlappen.
- Blockweise Verarbeitung: Führt blockweise Matrixmultiplikation und Softmax-Operationen im Wechsel aus.
- Verarbeitung mit geringer Präzision: Verbessert die Leistung durch Unterstützung für FP8 mit geringer Präzision.
-
Leistungssteigerungen von FlashAttention-3
- Effiziente GPU-Auslastung: Nutzt bis zu 75 % der Spitzenleistung der H100-GPU und ist damit 1,5- bis 2-mal schneller als die vorherige Version.
- Leistung bei geringer Präzision: Erhöht mit FP8 die Verarbeitungsgeschwindigkeit und reduziert den Speicherverbrauch.
- Verarbeitung langer Kontexte: Beschleunigt den Attention-Mechanismus, sodass längere Texte effizient verarbeitet werden können.
-
Zusammenfassung von FlashAttention
- FlashAttention ordnet die Attention-Berechnung neu an und nutzt Tiling sowie Rekombination, um die Geschwindigkeit deutlich zu erhöhen und den Speicherverbrauch zu senken.
- Durch Tiling werden Eingabeblöcke geladen, darauf Attention ausgeführt und anschließend die Ausgabe aktualisiert.
- Da die Zwischenmatrix der Attention nicht in den Speicher geschrieben wird, verringert sich die Menge an Speicher-Lese- und Schreibzugriffen.
-
Neue Hardware-Funktionen der Hopper-GPU
- WGMMA: Liefert hohen Durchsatz durch die Nutzung neuer Tensor Cores.
- TMA: Hardware-Einheit, die die Datenübertragung zwischen globalem Speicher und Shared Memory beschleunigt.
- FP8 mit geringer Präzision: Verdoppelt den Durchsatz der Tensor Cores durch den Einsatz von FP8.
-
Asynchronität: Überlappung von GEMM und Softmax
- Warum Überlappung nötig ist: Führt GEMM und Softmax parallel aus, um die Leistung zu maximieren.
- Ping-Pong-Scheduling: Zwei Warp-Gruppen führen abwechselnd GEMM und Softmax aus, um die Leistung zu verbessern.
- Überlappung innerhalb einer Warp-Gruppe: Führt GEMM und Softmax innerhalb derselben Warp-Gruppe parallel aus und erhöht so den Durchsatz.
-
Geringe Präzision: Reduzierung von Quantisierungsfehlern durch inkohärente Verarbeitung
- Inkohärente Verarbeitung: Reduziert Quantisierungsfehler mithilfe der Hadamard-Transformation.
- Experimentelle Ergebnisse: Verringert Quantisierungsfehler durch inkohärente Verarbeitung um das 2,6-Fache.
-
Attention-Benchmarks
- FP16: Etwa 1,6- bis 1,8-mal schneller als FlashAttention-2.
- FP8: Erreicht bis zu 1,2 PFLOPS.
Zusammenfassung von GN⁺
- FlashAttention-3 verbessert die Leistung des Attention-Mechanismus erheblich, indem es neue Hardware-Funktionen von GPUs nutzt.
- Es kann lange Kontexte effizient verarbeiten und maximiert damit die Leistung großer Sprachmodelle.
- Es wird voraussichtlich in wichtige Frameworks wie PyTorch integriert und dürfte dadurch großen Einfluss auf zukünftige KI-Forschung und Anwendungen haben.
- Projekte mit ähnlicher Funktionalität sind unter anderem Triton und cuDNN.
1 Kommentare
Hacker-News-Kommentare
Es scheint, dass Tri Dao bereits im April 2022 mit der Arbeit an FA3 begonnen hat
Es stellt sich die Frage, wie hardwareabhängig der Flash-Attention-Algorithmus ist
Es stellt sich die Frage, ob Compiler Optimierungen wie bei FlashAttention selbstständig finden können
Wer an einem Port auf ROCm/AMD MI300x interessiert ist, soll sich melden
TMA (Tensor Memory Accelerator) ist eine Hardware-Einheit, die die Datenübertragung zwischen globalem Speicher und Shared Memory beschleunigt
FlashAttention-3 ist für Hopper-GPUs (z. B. H100) optimiert
Es wird erwähnt, dass Aktivierungsfunktionen wie sigmoid in modernen LLMs sehr langsam sind
Es stellt sich die Frage, warum Flash Attention mit variablem Masking fünfmal langsamer ist als ohne
Es stellt sich die Frage, ob FlashAttention die Attention-Operationen in LLMs ersetzen kann
Teure Hardware ist erforderlich