- Diffusionsmodelle werden über die Bilderzeugung hinaus für Probleme eingesetzt, die Sampling aus multimodalen Verteilungen erfordern, etwa Audio, Video, 3D, Proteindesign und Roboter-Pfadplanung. Dieses Tutorial verbindet Training und Sampling aus Optimierungsperspektive
- Im Training werden verrauschte Daten (x_\sigma=x_0+\sigma\epsilon) erzeugt, und ein neuronales Netz (\epsilon_\theta(x,\sigma)) minimiert den mittleren quadratischen Fehler, um die Rauschrichtung vorherzusagen
- Der gelernte Denoiser wird als approximative Projektion auf die Datenmenge (\mathcal{K}) interpretiert; der ideale Denoiser steht mit dem Gradienten einer (\sigma)-geglätteten quadratischen Distanzfunktion in Verbindung
- DDIM-Sampling lässt sich als approximativer Gradientenabstieg für (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) auffassen; der (\sigma_t)-Schedule bestimmt die Zahl der Iterationen und die Kosten der Denoiser-Auswertung
- Kombiniert man Updates zur Gradientenschätzung mit dem Hinzufügen von Rauschen, lassen sich DDIM, DDPM und der verbesserte Sampler der Autoren gemeinsam über die Parameter
gamundmubehandeln; anschließend folgen ein Toy-Modell und Beispiele für Latent Diffusion
Diffusionsmodelle aus Optimierungsperspektive
- Diffusionsmodelle sind besonders stark darin, Samples aus multimodalen Verteilungen zu erzeugen, und werden nicht nur in Text-zu-Bild-Tools wie Stable Diffusion, sondern auch für Audio, Video, 3D-Erzeugung, Proteindesign und Roboter-Pfadplanung eingesetzt
- Die theoretische Grundlage des Tutorials ist die Optimierungsinterpretation aus einem ICML-2024-Paper und einem verwandten Paper
- Die Implementierung orientiert sich vor allem an
smalldiffusion; der Code im Text ist zu Lehrzwecken gegenüber der ursprünglichen Bibliothek vereinfacht
Training: Vorhersage der Rauschrichtung
- Diffusionsmodelle zielen darauf ab, aus Trainingsbeispielen eine Datenmenge (\mathcal{K}) zu lernen und aus dieser Menge Samples zu erzeugen
- Bei Bildern ist (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) die Menge von Pixelwerten, die realistischen Bildern entsprechen
- Derselbe Rahmen gilt auch für Audio, Video, Robotertrajektorien und diskrete Domänen wie Text
- Das Trainingsverfahren lässt sich in drei Schritte zerlegen
- (x_0 \sim \mathcal{K}), (\sigma) und (\epsilon \sim N(0,I)) werden gesampelt
- Mit (x_\sigma=x_0+\sigma\epsilon) werden verrauschte Daten erzeugt
- Die quadratische Loss wird minimiert, sodass (\epsilon_\theta(x_\sigma,\sigma)) das (\epsilon) vorhersagt
- Im Code erzeugt
training_loopfür jedes Batchx0mitgenerate_train_sampleeinsigmaund einepsund optimiert den MSE zwischen der Ausgabe vonmodel(x0 + sigma * eps, sigma)undeps - (\sigma) wird nicht gleichverteilt aus einem kontinuierlichen Intervall gesampelt, sondern aus einem in (N) Werte diskretisierten (\sigma)-Schedule gezogen
- Die Klasse
Schedulekapselt die Liste möglichersigmasund sampelt während des Trainings batchweise Werte daraus - Das Beispiel im Text verwendet
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMist ein Schedule für Diffusionsmodelle im Pixelraum,ScheduleLDMfür Latent-Diffusion-Modelle wie Stable Diffusion
- Die Klasse
Swissroll-Toy-Beispiel
- Das Toy-Dataset ist eine spiralförmige Punktmenge, wie sie in einem der frühen Diffusions-Paper von Sohl-Dickstein et al. 2015 verwendet wurde; es gilt (\mathcal{K}\subset\mathbb{R}^2)
- Bei einfachen Datasets wird der Denoiser als MLP implementiert
- Die Eingabe ist die Konkatenation von (x\in\mathbb{R}^2) und einem zweidimensionalen Embedding von (\sigma)
- Die Ausgabe ist die Vorhersage für das Rauschen (\epsilon\in\mathbb{R}^2)
- Viele Diffusionsmodelle verwenden für (\sigma) sinusoidale Positional Embeddings, aber in diesem Beispiel funktioniert auch ein einfaches zweidimensionales Embedding gut
- Die Beispiel-Trainingseinstellung verwendet
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)undepochs=15000 - Der gelernte Denoiser kann visualisiert werden, indem man (x-\sigma\epsilon_\theta(x,\sigma)) als Vektorfeld zeichnet
- Bei großem (\sigma) tendiert der Denoiser dazu, den Datenmittelwert vorherzusagen
- Bei kleinem (\sigma) und wenn die Eingabe (x) nahe an den Daten liegt, sagt er den tatsächlichen Datenpunkt voraus
Denoising als Projektion interpretieren
- Die Distanzfunktion zur Datenmenge (\mathcal{K}) ist definiert als (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- Die Projektion (\mathrm{proj}_{\mathcal{K}}(x)) von (x) ist die Menge der Punkte in (\mathcal{K}), die diese Distanz erreichen
- Wenn (\mathcal{K}) abgeschlossen ist, (x\notin\mathcal{K}) gilt und die Projektion eindeutig ist, ist der Gradient der quadratischen Distanzfunktion (x-\mathrm{proj}_{\mathcal{K}}(x))
- Da die Distanzfunktion (\mathrm{dist}_{\mathcal{K}}) nicht überall differenzierbar ist, wird statt
minein softmin verwendet, um eine mit (\sigma) geglättete quadratische Distanzfunktion einzuführen - Der Gradient der geglätteten Distanzfunktion zeigt in Richtung des gewichteten Mittels der Punkte von (\mathcal{K}), abhängig von den durch (x) bestimmten Gewichten
Idealer Denoiser und relatives Fehlermodell
- Der ideale Denoiser (\epsilon^*) ist ein Denoiser, der die Trainings-Loss für ein bestimmtes (\sigma) exakt minimiert
- Wenn die Daten eine diskrete Gleichverteilung auf einer endlichen Menge (\mathcal{K}) sind, lässt sich der ideale Denoiser in geschlossener Form ausdrücken
- Das Gewicht jedes Datenpunkts wird durch die Distanz zwischen (x_\sigma) und diesem Punkt bestimmt
- Bei kleinen Datasets kann er mit
IdealDenoiserdirekt berechnet werden
- Auf Toy-Daten zeigt der ideale Denoiser bei großem (\sigma) zum Datenmittelwert und bei kleinem (\sigma) zum nächstgelegenen Datenpunkt
- Der zentrale Satz stellt für alle (\sigma>0), (x\in\mathbb{R}^n) die Beziehung (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) her
- Das relative Fehlermodell verwendet die Bedingung, dass (x-\sigma\epsilon_\theta(x,\sigma)) (\mathrm{proj}_{\mathcal{K}}(x)) gut approximiert
- Es gilt, wenn (\sqrt{n}\sigma) (\mathrm{dist}_{\mathcal{K}}(x)) bis auf einen konstanten Faktor gut schätzt
- Es wird angenommen, dass der Fehler auf höchstens (\eta\mathrm{dist}_{\mathcal{K}}(x)) begrenzt ist
- Bei niedrigem Rauschen ist unter der Manifold Hypothesis das meiste zusätzliche Rauschen orthogonal zur Datenmannigfaltigkeit, sodass Denoising die Projektion approximiert
- Bei hohem Rauschen hat selbst ein Denoiser, der den gewichteten Mittelwert der Daten vorhersagt, einen kleinen relativen Fehler, wenn (\sigma) größer ist als der Durchmesser von (\mathcal{K})
- CIFAR-10 ist klein genug, um den idealen Denoiser zu berechnen; in Experimenten ist der relative Fehler zwischen der exakten Projektion entlang der Sampling-Trajektorie und der Ausgabe des idealen Denoisers klein
Sampling: iteratives Denoising und DDIM
- Hat man einen gelernten Denoiser, wird bei verrauschtem (x_t) und Rauschlevel (\sigma_t) mit (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t)) das (x_0) vorhergesagt
- Der Startpunkt wird so gewählt, dass (\sigma_T) im Vergleich zum Durchmesser von (\mathcal{K}) groß ist, und (x_T) wird unabhängig aus (N(0,\sigma_T)) gesampelt, sodass er weit von (\mathcal{K}) entfernt liegt
- Bei hohem Rauschen kann ein einzelner Denoiser-Aufruf trotz kleinem relativen Fehler einen großen absoluten Fehler haben; die Vorhersage des idealen Denoisers liegt nahe am Datenmittelwert
- Deshalb ruft das Sampling den Denoiser wiederholt entlang eines (\sigma_t)-Schedules auf und erzeugt eine Sequenz (x_T,\ldots,x_0)
- Das Update (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) entspricht nach einer Koordinatentransformation dem deterministischen DDIM-Sampling-Algorithmus
- Der Beweis der Äquivalenz zu DDIM steht in Appendix A des Papers
DDIM als Distanzminimierung
- DDIM wird als approximativer Gradientenabstieg für (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) interpretiert
- Die Schrittweite ist (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) wird durch (\epsilon_\theta(x_t,\sigma_t)) geschätzt
- Der (\sigma_t)-Schedule bestimmt beim Sampling die Anzahl und Größe der Gradientenschritte
- Sind es zu wenige Schritte, verringert sich (\mathrm{dist}_{\mathcal{K}}(x_t)) möglicherweise nicht, sodass keine Konvergenz eintritt
- Viele kleine Schritte erhöhen die Zahl der Denoiser-Auswertungen und damit die Rechenkosten
- Ein admissible Schedule ist ein Schedule, bei dem in jeder Iteration (\sqrt{n}\sigma_t) bis auf einen konstanten Faktor zu (\mathrm{dist}_{\mathcal{K}}(x_t)) passt
- Eine geometrisch abnehmende log-lineare (\sigma_t)-Sequenz ist ein admissible Schedule
- Dem Satz zufolge gilt: Wenn für die durch DDIM erzeugten (x_t) der Gradient (\nabla\mathrm{dist}{\mathcal{K}}(x)) existiert und (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T) ist, dann wird (x_t) durch Gradientenabstieg auf der quadratischen Distanzfunktion erzeugt, und (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) bleibt erhalten
- Im Toy-Beispiel wird ein DDIM-Sampler mit 20 Schritten implementiert, indem aus dem ursprünglichen log-linearen Schedule subsampelt wird; die meisten Samples liegen nahe an den Originaldaten, aber es bleibt Raum für Verbesserungen
Verbesserter Sampler auf Basis von Gradientenschätzung
- Ausgenutzt wird, dass (\nabla\mathrm{dist}{\mathcal{K}}(x)) zwischen (x) und (\mathrm{proj}{\mathcal{K}}(x)) invariant ist; dafür wird ein Update verwendet, das die aktuelle und die vorherige Schätzung mischt
- Das Update (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) korrigiert den Fehler des vorherigen Schritts mit der aktuellen Schätzung
- In Samples des Toy-Modells konvergiert diese Methode schneller als DDIM, und die Samples liegen näher an den Originaldaten
- Im Vergleich zu DDIM lässt sich dieser Sampler als Erweiterung um Momentum interpretieren; die Trajektorie kann overshooten, konvergiert aber möglicherweise schneller
- Fügt man während der Generierung Rauschen hinzu, verbessert sich die Sampling-Qualität empirisch
- Um den ursprünglichen (\sigma_t)-Schedule beizubehalten, denoist man zunächst bis zu einem kleineren (\sigma_{t'}) und fügt dann wieder Rauschen (w_t\sim N(0,I)) hinzu
- Für (\mu=\frac{1}{2}) wird der DDPM-Sampler exakt wiederhergestellt
- Das vollständige Update (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) verallgemeinert drei Sampler
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - Gradientenschätzungs-Sampler:
gam=2, mu=0
- DDIM:
Größere Modelle und Referenzen
- Der obige Trainingscode kann nicht nur für Toy-Daten verwendet werden, sondern auch, um Bild-Diffusionsmodelle von Grund auf zu trainieren
- Das FashionMNIST-Beispiel wird als Beispiel bereitgestellt, das auf dem FashionMNIST-Dataset trainiert und nach FID den zweiten Platz im Papers with Code Leaderboard erreicht
- Der Sampling-Code kann unverändert auch für vortrainierte Latent-Diffusion-Modelle verwendet werden
- Das Beispiel nutzt
ScheduleLDM(1000)undModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - Als Textbedingung wird
An astronaut riding a horsegesetzt; nach Sampling mit 50 (\sigma)-Schritten wird das Latent dekodiert
- Das Beispiel nutzt
- Der Effekt des (\gamma)-Momentum-Terms wird in einer visuellen Gegenüberstellung bei hochauflösender Text-zu-Bild-Generierung gezeigt
- Weitere lesenswerte Materialien
- What are diffusion models: Einführung in Diffusionsmodelle aus der diskreten Zeitperspektive, die einen Markov process umkehrt
- Generative modeling by estimating gradients of the data distribution: Einführung in Diffusionsmodelle aus der kontinuierlichen Zeitperspektive, die stochastische Differentialgleichungen umkehrt
- The annotated diffusion model: ausführliche Erklärung einer PyTorch-Implementierung eines Diffusionsmodells
1 Kommentare
Hacker-News-Kommentare
Wenn es Fragen gibt, kann ich sie beantworten.
Besonders gut fand ich die Diskussion der Trajektorien, weil sie motiviert, den Teil zu verstehen, mit dem viele bei Themen wie Schedulern Schwierigkeiten haben. Er ist zwar nicht so vollständig wie die Beiträge von Song oder Lilian, aber deutlich zugänglicher, daher werde ich ihn anderen empfehlen.
Als Hinweis: Ein Freund hat früher eine Minimalimplementierung von Diffusion geschrieben, die aus DDPM-Perspektive etwas „vollständiger“ ist und nützlich war: https://github.com/VSehwag/minimal-diffusion/
Da ich selbst ein wenig mit dem Sampling-Verfahren in Stable Diffusion experimentiert habe, hätte ich auch gern einen Vergleich der Konvergenzzeit und Schrittzahl gegenüber DDIM gesehen. Mich interessiert, ob es einen Zusammenhang zwischen Momentum, Konvergenz und Fehler gibt. Zum Beispiel wäre ein Vergleich interessant, ob ein Momentum-Sampler mit 16 Schritten in etwa DDIM mit 20 Schritten ± Fehlerterm entspricht.
get_sigma_embeds(batches, sigma)scheint den ersten Input nicht zu verwenden. Ich frage mich, ob die Absicht war,sigmaauf die Form(batches, 1)zu broadcasten.Er geht viel tiefer auf die mathematischen Details ein und enthält zugleich eine sehr verständliche Minimalimplementierung mit weniger als 500 Zeilen.
Es wäre schön, wenn das auch auf Versionen mit Diffusion-Transformern ausgeweitet würde, wie sie Sora und andere Videogenerierungsmodelle antreiben. Aus diesem Beitrag und https://jaykmody.com/blog/gpt-from-scratch/ ließe sich wohl ein Einführungsartikel „Diffusion Transformer from scratch“ machen.
Wenn man dagegen wirklich tief einsteigen will, empfehle ich die Arbeiten von Kingma, Gao, Ricky Tian Qi Chen sowie den Schülern von Max Welling (Tomczak als Postdoc, Hoogeboom usw.) und dem unterschätzten Aapo Hyvärinen zu lesen. Ein Beispiel aus der vergleichsweise leichtgewichtigeren Arbeit von Kingma & Gao, das auch mit dem SD3-Paper zusammenhängt, ist hier: https://arxiv.org/abs/2303.00848
Der Nachteil ist, dass die Abhängigkeit davon, frühere Arbeiten zu kennen und zu verstehen, groß ist und dadurch die Zugänglichkeit sinkt. Es ist aber schwer, das als sinnvolle Kritik zu bezeichnen, denn es handelt sich um Forschung und nicht um Lehrmaterial für die breite Öffentlichkeit.
n_embdprojizieren, und der Diffusionsprozess selbst kann unverändert bleiben.[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
Aus unserer Sicht sind Diffusionsmodelle leicht zu trainieren, weil sie als Trainingsziel nicht den Gradienten der exakten Distanzfunktion vorhersagen, sondern den Gradienten einer geglätteten Distanzfunktion. Sampling mit Diffusionsmodellen ist ähnlich wie das Ausführen mehrerer angenäherter Gradientenschritte.
Wer Diffusionsmodelle tiefer verstehen möchte, dem empfehle ich, all diese Blogbeiträge zu lesen und die unterschiedlichen Interpretationen kennenzulernen.
Allerdings scheint der Ansatz dieses Beitrags interessantere Experimente zu ermöglichen, etwa zur Fehleranalyse des Denoisers.
[1] https://arxiv.org/pdf/2305.03486.pdf
Warum fällt es einem Bildgenerator zum Beispiel schwer, Klaviertasten zu erzeugen? Um die Struktur zu erzeugen, bei der sich schwarze Tasten in Zweier- und Dreiergruppen abwechseln, müsste er anscheinend Zwischen-Distanzbeschränkungen besser darstellen.