2 Punkte von GN⁺ 2024-03-12 | 1 Kommentare | Auf WhatsApp teilen
  • Diffusionsmodelle werden über die Bilderzeugung hinaus für Probleme eingesetzt, die Sampling aus multimodalen Verteilungen erfordern, etwa Audio, Video, 3D, Proteindesign und Roboter-Pfadplanung. Dieses Tutorial verbindet Training und Sampling aus Optimierungsperspektive
  • Im Training werden verrauschte Daten (x_\sigma=x_0+\sigma\epsilon) erzeugt, und ein neuronales Netz (\epsilon_\theta(x,\sigma)) minimiert den mittleren quadratischen Fehler, um die Rauschrichtung vorherzusagen
  • Der gelernte Denoiser wird als approximative Projektion auf die Datenmenge (\mathcal{K}) interpretiert; der ideale Denoiser steht mit dem Gradienten einer (\sigma)-geglätteten quadratischen Distanzfunktion in Verbindung
  • DDIM-Sampling lässt sich als approximativer Gradientenabstieg für (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) auffassen; der (\sigma_t)-Schedule bestimmt die Zahl der Iterationen und die Kosten der Denoiser-Auswertung
  • Kombiniert man Updates zur Gradientenschätzung mit dem Hinzufügen von Rauschen, lassen sich DDIM, DDPM und der verbesserte Sampler der Autoren gemeinsam über die Parameter gam und mu behandeln; anschließend folgen ein Toy-Modell und Beispiele für Latent Diffusion

Diffusionsmodelle aus Optimierungsperspektive

  • Diffusionsmodelle sind besonders stark darin, Samples aus multimodalen Verteilungen zu erzeugen, und werden nicht nur in Text-zu-Bild-Tools wie Stable Diffusion, sondern auch für Audio, Video, 3D-Erzeugung, Proteindesign und Roboter-Pfadplanung eingesetzt
  • Die theoretische Grundlage des Tutorials ist die Optimierungsinterpretation aus einem ICML-2024-Paper und einem verwandten Paper
  • Die Implementierung orientiert sich vor allem an smalldiffusion; der Code im Text ist zu Lehrzwecken gegenüber der ursprünglichen Bibliothek vereinfacht

Training: Vorhersage der Rauschrichtung

  • Diffusionsmodelle zielen darauf ab, aus Trainingsbeispielen eine Datenmenge (\mathcal{K}) zu lernen und aus dieser Menge Samples zu erzeugen
    • Bei Bildern ist (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) die Menge von Pixelwerten, die realistischen Bildern entsprechen
    • Derselbe Rahmen gilt auch für Audio, Video, Robotertrajektorien und diskrete Domänen wie Text
  • Das Trainingsverfahren lässt sich in drei Schritte zerlegen
    • (x_0 \sim \mathcal{K}), (\sigma) und (\epsilon \sim N(0,I)) werden gesampelt
    • Mit (x_\sigma=x_0+\sigma\epsilon) werden verrauschte Daten erzeugt
    • Die quadratische Loss wird minimiert, sodass (\epsilon_\theta(x_\sigma,\sigma)) das (\epsilon) vorhersagt
  • Im Code erzeugt training_loop für jedes Batch x0 mit generate_train_sample ein sigma und ein eps und optimiert den MSE zwischen der Ausgabe von model(x0 + sigma * eps, sigma) und eps
  • (\sigma) wird nicht gleichverteilt aus einem kontinuierlichen Intervall gesampelt, sondern aus einem in (N) Werte diskretisierten (\sigma)-Schedule gezogen
    • Die Klasse Schedule kapselt die Liste möglicher sigmas und sampelt während des Trainings batchweise Werte daraus
    • Das Beispiel im Text verwendet ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10)
    • ScheduleDDPM ist ein Schedule für Diffusionsmodelle im Pixelraum, ScheduleLDM für Latent-Diffusion-Modelle wie Stable Diffusion

Swissroll-Toy-Beispiel

  • Das Toy-Dataset ist eine spiralförmige Punktmenge, wie sie in einem der frühen Diffusions-Paper von Sohl-Dickstein et al. 2015 verwendet wurde; es gilt (\mathcal{K}\subset\mathbb{R}^2)
  • Bei einfachen Datasets wird der Denoiser als MLP implementiert
    • Die Eingabe ist die Konkatenation von (x\in\mathbb{R}^2) und einem zweidimensionalen Embedding von (\sigma)
    • Die Ausgabe ist die Vorhersage für das Rauschen (\epsilon\in\mathbb{R}^2)
    • Viele Diffusionsmodelle verwenden für (\sigma) sinusoidale Positional Embeddings, aber in diesem Beispiel funktioniert auch ein einfaches zweidimensionales Embedding gut
  • Die Beispiel-Trainingseinstellung verwendet ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10) und epochs=15000
  • Der gelernte Denoiser kann visualisiert werden, indem man (x-\sigma\epsilon_\theta(x,\sigma)) als Vektorfeld zeichnet
    • Bei großem (\sigma) tendiert der Denoiser dazu, den Datenmittelwert vorherzusagen
    • Bei kleinem (\sigma) und wenn die Eingabe (x) nahe an den Daten liegt, sagt er den tatsächlichen Datenpunkt voraus

Denoising als Projektion interpretieren

  • Die Distanzfunktion zur Datenmenge (\mathcal{K}) ist definiert als (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
  • Die Projektion (\mathrm{proj}_{\mathcal{K}}(x)) von (x) ist die Menge der Punkte in (\mathcal{K}), die diese Distanz erreichen
  • Wenn (\mathcal{K}) abgeschlossen ist, (x\notin\mathcal{K}) gilt und die Projektion eindeutig ist, ist der Gradient der quadratischen Distanzfunktion (x-\mathrm{proj}_{\mathcal{K}}(x))
  • Da die Distanzfunktion (\mathrm{dist}_{\mathcal{K}}) nicht überall differenzierbar ist, wird statt min ein softmin verwendet, um eine mit (\sigma) geglättete quadratische Distanzfunktion einzuführen
  • Der Gradient der geglätteten Distanzfunktion zeigt in Richtung des gewichteten Mittels der Punkte von (\mathcal{K}), abhängig von den durch (x) bestimmten Gewichten

Idealer Denoiser und relatives Fehlermodell

  • Der ideale Denoiser (\epsilon^*) ist ein Denoiser, der die Trainings-Loss für ein bestimmtes (\sigma) exakt minimiert
  • Wenn die Daten eine diskrete Gleichverteilung auf einer endlichen Menge (\mathcal{K}) sind, lässt sich der ideale Denoiser in geschlossener Form ausdrücken
    • Das Gewicht jedes Datenpunkts wird durch die Distanz zwischen (x_\sigma) und diesem Punkt bestimmt
    • Bei kleinen Datasets kann er mit IdealDenoiser direkt berechnet werden
  • Auf Toy-Daten zeigt der ideale Denoiser bei großem (\sigma) zum Datenmittelwert und bei kleinem (\sigma) zum nächstgelegenen Datenpunkt
  • Der zentrale Satz stellt für alle (\sigma>0), (x\in\mathbb{R}^n) die Beziehung (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) her
  • Das relative Fehlermodell verwendet die Bedingung, dass (x-\sigma\epsilon_\theta(x,\sigma)) (\mathrm{proj}_{\mathcal{K}}(x)) gut approximiert
    • Es gilt, wenn (\sqrt{n}\sigma) (\mathrm{dist}_{\mathcal{K}}(x)) bis auf einen konstanten Faktor gut schätzt
    • Es wird angenommen, dass der Fehler auf höchstens (\eta\mathrm{dist}_{\mathcal{K}}(x)) begrenzt ist
    • Bei niedrigem Rauschen ist unter der Manifold Hypothesis das meiste zusätzliche Rauschen orthogonal zur Datenmannigfaltigkeit, sodass Denoising die Projektion approximiert
    • Bei hohem Rauschen hat selbst ein Denoiser, der den gewichteten Mittelwert der Daten vorhersagt, einen kleinen relativen Fehler, wenn (\sigma) größer ist als der Durchmesser von (\mathcal{K})
  • CIFAR-10 ist klein genug, um den idealen Denoiser zu berechnen; in Experimenten ist der relative Fehler zwischen der exakten Projektion entlang der Sampling-Trajektorie und der Ausgabe des idealen Denoisers klein

Sampling: iteratives Denoising und DDIM

  • Hat man einen gelernten Denoiser, wird bei verrauschtem (x_t) und Rauschlevel (\sigma_t) mit (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t)) das (x_0) vorhergesagt
  • Der Startpunkt wird so gewählt, dass (\sigma_T) im Vergleich zum Durchmesser von (\mathcal{K}) groß ist, und (x_T) wird unabhängig aus (N(0,\sigma_T)) gesampelt, sodass er weit von (\mathcal{K}) entfernt liegt
  • Bei hohem Rauschen kann ein einzelner Denoiser-Aufruf trotz kleinem relativen Fehler einen großen absoluten Fehler haben; die Vorhersage des idealen Denoisers liegt nahe am Datenmittelwert
  • Deshalb ruft das Sampling den Denoiser wiederholt entlang eines (\sigma_t)-Schedules auf und erzeugt eine Sequenz (x_T,\ldots,x_0)
  • Das Update (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) entspricht nach einer Koordinatentransformation dem deterministischen DDIM-Sampling-Algorithmus
    • Der Beweis der Äquivalenz zu DDIM steht in Appendix A des Papers

DDIM als Distanzminimierung

  • DDIM wird als approximativer Gradientenabstieg für (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) interpretiert
    • Die Schrittweite ist (1-\sigma_{t-1}/\sigma_t)
    • (\nabla f(x_t)) wird durch (\epsilon_\theta(x_t,\sigma_t)) geschätzt
  • Der (\sigma_t)-Schedule bestimmt beim Sampling die Anzahl und Größe der Gradientenschritte
    • Sind es zu wenige Schritte, verringert sich (\mathrm{dist}_{\mathcal{K}}(x_t)) möglicherweise nicht, sodass keine Konvergenz eintritt
    • Viele kleine Schritte erhöhen die Zahl der Denoiser-Auswertungen und damit die Rechenkosten
  • Ein admissible Schedule ist ein Schedule, bei dem in jeder Iteration (\sqrt{n}\sigma_t) bis auf einen konstanten Faktor zu (\mathrm{dist}_{\mathcal{K}}(x_t)) passt
    • Eine geometrisch abnehmende log-lineare (\sigma_t)-Sequenz ist ein admissible Schedule
  • Dem Satz zufolge gilt: Wenn für die durch DDIM erzeugten (x_t) der Gradient (\nabla\mathrm{dist}{\mathcal{K}}(x)) existiert und (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T) ist, dann wird (x_t) durch Gradientenabstieg auf der quadratischen Distanzfunktion erzeugt, und (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) bleibt erhalten
  • Im Toy-Beispiel wird ein DDIM-Sampler mit 20 Schritten implementiert, indem aus dem ursprünglichen log-linearen Schedule subsampelt wird; die meisten Samples liegen nahe an den Originaldaten, aber es bleibt Raum für Verbesserungen

Verbesserter Sampler auf Basis von Gradientenschätzung

  • Ausgenutzt wird, dass (\nabla\mathrm{dist}{\mathcal{K}}(x)) zwischen (x) und (\mathrm{proj}{\mathcal{K}}(x)) invariant ist; dafür wird ein Update verwendet, das die aktuelle und die vorherige Schätzung mischt
  • Das Update (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) korrigiert den Fehler des vorherigen Schritts mit der aktuellen Schätzung
  • In Samples des Toy-Modells konvergiert diese Methode schneller als DDIM, und die Samples liegen näher an den Originaldaten
  • Im Vergleich zu DDIM lässt sich dieser Sampler als Erweiterung um Momentum interpretieren; die Trajektorie kann overshooten, konvergiert aber möglicherweise schneller
  • Fügt man während der Generierung Rauschen hinzu, verbessert sich die Sampling-Qualität empirisch
    • Um den ursprünglichen (\sigma_t)-Schedule beizubehalten, denoist man zunächst bis zu einem kleineren (\sigma_{t'}) und fügt dann wieder Rauschen (w_t\sim N(0,I)) hinzu
    • Für (\mu=\frac{1}{2}) wird der DDPM-Sampler exakt wiederhergestellt
  • Das vollständige Update (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) verallgemeinert drei Sampler
    • DDIM: gam=1, mu=0
    • DDPM: gam=1, mu=0.5
    • Gradientenschätzungs-Sampler: gam=2, mu=0

Größere Modelle und Referenzen

  • Der obige Trainingscode kann nicht nur für Toy-Daten verwendet werden, sondern auch, um Bild-Diffusionsmodelle von Grund auf zu trainieren
  • Das FashionMNIST-Beispiel wird als Beispiel bereitgestellt, das auf dem FashionMNIST-Dataset trainiert und nach FID den zweiten Platz im Papers with Code Leaderboard erreicht
  • Der Sampling-Code kann unverändert auch für vortrainierte Latent-Diffusion-Modelle verwendet werden
    • Das Beispiel nutzt ScheduleLDM(1000) und ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
    • Als Textbedingung wird An astronaut riding a horse gesetzt; nach Sampling mit 50 (\sigma)-Schritten wird das Latent dekodiert
  • Der Effekt des (\gamma)-Momentum-Terms wird in einer visuellen Gegenüberstellung bei hochauflösender Text-zu-Bild-Generierung gezeigt
  • Weitere lesenswerte Materialien

1 Kommentare

 
GN⁺ 2024-03-12
Hacker-News-Kommentare
  • Ich bin der Autor. Beim Versuch, Diffusionsmodelle zu verstehen, ist mir klar geworden, dass sich Code und Mathematik stark vereinfachen lassen. Deshalb habe ich diesen Blogbeitrag und die Diffusionsbibliothek erstellt.
    Wenn es Fragen gibt, kann ich sie beantworten.
    • Aus Forschersicht gefallen mir viele Blogbeiträge über Diffusionsmodelle nicht, aber dieser hier war wirklich gut. Er kommt direkt zum Kern, zeigt zugleich die komplexen Stellen, an denen man oft hängen bleibt, ohne sich zu verlieren oder abzuschweifen.
      Besonders gut fand ich die Diskussion der Trajektorien, weil sie motiviert, den Teil zu verstehen, mit dem viele bei Themen wie Schedulern Schwierigkeiten haben. Er ist zwar nicht so vollständig wie die Beiträge von Song oder Lilian, aber deutlich zugänglicher, daher werde ich ihn anderen empfehlen.
      Als Hinweis: Ein Freund hat früher eine Minimalimplementierung von Diffusion geschrieben, die aus DDPM-Perspektive etwas „vollständiger“ ist und nützlich war: https://github.com/VSehwag/minimal-diffusion/
    • Im letzten Beispielbild scheint der Momentum-Term bei dem digitalen Gemälde des Hauses eher schädlich zu sein. Im Bild mit gamma = 2.0 ist die Tür verschwunden; um die Wirkung des DDIM-Samplers, der Gradienteninformationen nutzt, intuitiv zu verstehen, würden mich daher mehr Details zu diesem Beispiel interessieren.
      Da ich selbst ein wenig mit dem Sampling-Verfahren in Stable Diffusion experimentiert habe, hätte ich auch gern einen Vergleich der Konvergenzzeit und Schrittzahl gegenüber DDIM gesehen. Mich interessiert, ob es einen Zusammenhang zwischen Momentum, Konvergenz und Fehler gibt. Zum Beispiel wäre ein Vergleich interessant, ob ein Momentum-Sampler mit 16 Schritten in etwa DDIM mit 20 Schritten ± Fehlerterm entspricht.
    • get_sigma_embeds(batches, sigma) scheint den ersten Input nicht zu verwenden. Ich frage mich, ob die Absicht war, sigma auf die Form (batches, 1) zu broadcasten.
    • Ich frage mich, ob einige dieser Konzepte aus physikalischen Prinzipien stammen. Ist das ähnlich wie die Aussage, neuronale Netze seien biologischen neuronalen Netzen nachempfunden, oder gibt es dazu eine Einsicht aus dieser Perspektive?
  • Ein weiterer guter Beitrag trägt ebenfalls den Titel Diffusion Models From Scratch: https://www.tonyduan.com/diffusion/index.html
    Er geht viel tiefer auf die mathematischen Details ein und enthält zugleich eine sehr verständliche Minimalimplementierung mit weniger als 500 Zeilen.
  • Schön, dass es Code gibt. Diffusions-Paper sind berüchtigt dafür, viele Gleichungen zu enthalten (https://twitter.com/cto_junior/status/1766518604395155830), aber für den Rest von uns ist Code viel leichter zu lesen und möglicherweise auch präziser. Meiner Meinung nach sollte jedes theoretische Paper Referenzimplementierungscode enthalten.
    Es wäre schön, wenn das auch auf Versionen mit Diffusion-Transformern ausgeweitet würde, wie sie Sora und andere Videogenerierungsmodelle antreiben. Aus diesem Beitrag und https://jaykmody.com/blog/gpt-from-scratch/ ließe sich wohl ein Einführungsartikel „Diffusion Transformer from scratch“ machen.
    • Diffusions-Paper sind zwar dafür berüchtigt, viele Gleichungen zu enthalten, aber ehrlich gesagt reagieren die meisten Diffusionsforscher, die ich kenne, genauso. Viele Leute schreiben dieselben Gleichungen immer wieder hin, und diese Gleichungen dienen im Grunde eher der Wiederholung.
      Wenn man dagegen wirklich tief einsteigen will, empfehle ich die Arbeiten von Kingma, Gao, Ricky Tian Qi Chen sowie den Schülern von Max Welling (Tomczak als Postdoc, Hoogeboom usw.) und dem unterschätzten Aapo Hyvärinen zu lesen. Ein Beispiel aus der vergleichsweise leichtgewichtigeren Arbeit von Kingma & Gao, das auch mit dem SD3-Paper zusammenhängt, ist hier: https://arxiv.org/abs/2303.00848
      Der Nachteil ist, dass die Abhängigkeit davon, frühere Arbeiten zu kennen und zu verstehen, groß ist und dadurch die Zugänglichkeit sinkt. Es ist aber schwer, das als sinnvolle Kritik zu bezeichnen, denn es handelt sich um Forschung und nicht um Lehrmaterial für die breite Öffentlichkeit.
    • Man muss einfach das U-Net durch einen Transformer-Encoder ersetzen. Die Embeddings entfernen, Bild-Patches auf Vektoren der Größe n_embd projizieren, und der Diffusionsprozess selbst kann unverändert bleiben.
  • Ein guter Beitrag, aber mir fehlt die wichtige Eigenschaft[1], dass Diffusionsmodelle die Score-Funktion (die Ableitung der Log-Wahrscheinlichkeit) modellieren, sowie der Punkt, dass Diffusions-Sampling der Langevin-Dynamik[2] ähnelt. Ich denke, diese Perspektiven erklären gut, warum das Training einfacher ist als bei GANs: Das Modellierungsziel ist leichter.
    [1] https://yang-song.net/blog/2021/score/
    [2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
    • Stimmt. Diese Blogbeiträge bieten eine Interpretation von Diffusionsmodellen, die sich von der im Text beschriebenen Perspektive der „Projektion auf die Daten“ unterscheidet. Man kann sie als mehrere Arten sehen, dasselbe Trainingsziel und denselben Sampling-Prozess zu interpretieren.
      Aus unserer Sicht sind Diffusionsmodelle leicht zu trainieren, weil sie als Trainingsziel nicht den Gradienten der exakten Distanzfunktion vorhersagen, sondern den Gradienten einer geglätteten Distanzfunktion. Sampling mit Diffusionsmodellen ist ähnlich wie das Ausführen mehrerer angenäherter Gradientenschritte.
      Wer Diffusionsmodelle tiefer verstehen möchte, dem empfehle ich, all diese Blogbeiträge zu lesen und die unterschiedlichen Interpretationen kennenzulernen.
  • Sehr interessant. Mir kam sofort Iterative alpha-(de)Blending[1] in den Sinn. Auch diese Arbeit versucht, ein konzeptionell einfacheres Diffusionsmodell aufzubauen, und kommt zu dem Schluss, es als approximativen iterativen Projektionsprozess zu formulieren.
    Allerdings scheint der Ansatz dieses Beitrags interessantere Experimente zu ermöglichen, etwa zur Fehleranalyse des Denoisers.
    [1] https://arxiv.org/pdf/2305.03486.pdf
  • Gute theoretische Erklärung. Sie wirkt unabhängig vom Datensatz, aber mich interessieren die konkreten Aspekte der tatsächlichen Bildgenerierung.
    Warum fällt es einem Bildgenerator zum Beispiel schwer, Klaviertasten zu erzeugen? Um die Struktur zu erzeugen, bei der sich schwarze Tasten in Zweier- und Dreiergruppen abwechseln, müsste er anscheinend Zwischen-Distanzbeschränkungen besser darstellen.
    • Das ist wie das Fingerproblem. Anzahl, Größe, Winkel, Position usw. müssen jedes Mal alle stimmen, und wenn auch nur eine Sache nicht passt, merken Menschen das sehr schnell. Das ist anders als bei Objekten wie Baumästen, bei denen Menschen kaum bemerken, wenn die Stelle der Verzweigung „falsch“ ist.
  • Ist ein Teil der Idee hinter Diffusion, die Trainingsdaten enorm zu vergrößern? Also in dem Sinn, dass man zufällig diffundierte Bilder ihren ursprünglichen, nicht diffundierten Bildern gegenüberstellen kann?
  • Alle Machine-Learning-Modelle sind Faltungen. Wartet nur ab.
    • Ich glaube, du hast das schon ein paar Mal gepostet. Kannst du das etwas genauer erklären? Zum Beispiel fällt es mir schwer, Reinforcement Learning als Faltung zu betrachten.