1 Punkte von GN⁺ 2024-07-14 | Noch keine Kommentare. | Auf WhatsApp teilen
  • AlphaFold3 versucht, über einzelne Proteine hinaus Komplexe aus Proteinen, Nukleinsäuren und kleinen Molekülen allein anhand der Sequenz vorherzusagen; dadurch werden Eingaberepräsentation und Tokenisierung deutlich komplexer als bei AF2
  • Die Eingaben gliedern sich in single/pair-Repräsentationen auf Token-Ebene, Repräsentationen auf Atomebene, MSA und Templates; Standardaminosäuren und -nukleotide werden als 1 Token behandelt, nichtstandardmäßige Reste und andere Moleküle als 1 Token pro Atom
  • Der Representation-Learning-Trunk verbessert über Template-Modul, MSA-Modul und Pairformer die single-Repräsentation s und die pair-Repräsentation z wiederholt mittels pair-bias attention, triangle-Operationen und recycling
  • Die Strukturvorhersage verwendet statt AF2s Invariant Point Attention ein bedingtes Diffusionsmodell für Atomkoordinaten und erzeugt durch Rotations-/Translations-Augmentierung und Denoising Koordinaten-Updates für alle Atome
  • Das Training kombiniert Distogram-, Diffusion- und Confidence-Loss und lernt durch Cross-Distillation unter Nutzung von Ergebnissen aus AF2 und AF-Multimer sogar unfolded-Repräsentationen in Bereichen mit niedriger Konfidenz neu

Eingabebereich und Gesamtpipeline von AlphaFold3

  • Ziel von AlphaFold3 ist es, nicht wie AF2 nur einzelne Proteinsequenzen vorherzusagen oder wie AF-Multimer nur Proteinkomplexe zu behandeln, sondern Strukturen aus Proteinen und optional weiteren Proteinen, Nukleinsäuren oder kleinen Molekülen allein anhand der Sequenz vorherzusagen
  • Die Bedeutung von „Token“ hängt von der Art der Eingabe ab
    • Protein: 1 Standardaminosäure ist 1 Token
    • DNA/RNA: 1 Standardnukleotid ist 1 Token
    • Nichtstandardmäßige Aminosäuren/Nukleotide: 1 Atom ist 1 Token
    • Andere Moleküle: 1 Atom ist 1 Token
  • Ein Protein aus 35 Standardaminosäuren kann tatsächlich mehr als 600 Atome enthalten, wird aber als 35 Token dargestellt; ein Ligand mit 35 Atomen wird als 35 Token dargestellt
  • Das Modell besteht grob aus drei Phasen
    • Input Preparation: Umwandlung der vom Nutzer eingegebenen Sequenzen sowie gefundener verwandter Sequenzen und Strukturen in numerische Tensoren
    • Representation Learning: Aktualisierung von single- und pair-Repräsentationen mit mehreren Attention-Varianten
    • Structure Prediction: Vorhersage der Struktur per bedingter Diffusion
  • Proteinkomplexe werden hauptsächlich in zwei Repräsentationen gespeichert
    • single representation: repräsentiert alle Token des Komplexes selbst
    • pair representation: repräsentiert Beziehungen wie Abstände und potenzielle Interaktionen zwischen allen Token-Paaren
  • Die wichtigsten Kanaldimensionen sind c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384

Eingabevorbereitung: Wie Sequenzen in 6 Tensoren umgewandelt werden

  • Die vom Nutzer bereitgestellte Eingabe wird in 6 Tensoren umgewandelt, die in den Modell-Trunk gehen
    • s: token-level single representation
    • z: token-level pair representation
    • q: atom-level single representation
    • p: atom-level pair representation
    • m: MSA representation
    • t: template representation
  • MSA- und Template-Suche

    • AF3 sucht für Protein- und RNA-Sequenzen nach ähnlichen Sequenzen und organisiert sie als MSA; verwandte Strukturen werden als template einbezogen
    • MSA richtet ähnliche Proteinsequenzen aus, die in verschiedenen Spezies gefunden wurden, und liefert dem Modell Erhaltungsmuster an bestimmten Positionen sowie Korrelationen von Veränderungen zwischen unterschiedlichen Positionen
    • Bekannte Strukturen ähnlicher Proteine werden wie beim Homology Modeling genutzt, um die Struktur des Query-Proteins abzuschätzen
    • Die Suche umfasst kein Training; es werden HMM-basierte Methoden verwendet
    • Mit jackhmmer, HHBlits und nhmmer werden mehrere Protein- und RNA-Datenbanken durchsucht, und mit hmmsearch werden ähnliche Sequenzen in der Protein Data Bank gefunden
    • Die MSA-Größe ist wegen der Rechenkomplexität auf N_MSA < 2^14 begrenzt
    • Für jede Protein-chain werden hochwertige Strukturen ausgewählt, und bis zu 4 davon als template gesampelt
    • Das im Vergleich zu AF-Multimer neu hinzugefügte Suchelement ist, dass auch RNA-Sequenzen als Suchziel einbezogen werden
  • Template-Repräsentation

    • Aus der 3D-Struktur des template wird der euklidische Abstand zwischen jedem Token-Paar berechnet
    • Für Token mit mehreren Atomen wird ein repräsentatives „center atom“ verwendet
      • Aminosäure: -Atom
      • Standardnukleotid: C1'-Atom
    • Die Abstandswerte werden nicht als kontinuierliche Werte, sondern als distogram diskretisiert
      • 38 Bins von 3,15 Å bis 50,75 Å
      • 1 zusätzlicher Bin für größere Abstände
    • Dem distogram werden Chain-Informationen, Angaben dazu, ob das betreffende Token in der crystal structure resolved ist, sowie lokale Abstandsinformationen innerhalb jeder Aminosäure hinzugefügt
    • Die template matrix wird so maskiert, dass nur Abstände innerhalb derselben Chain betrachtet werden; es wird nicht versucht, über die Template-Auswahl Informationen zu inter-chain interactions zu gewinnen

Darstellung auf Atomebene und Atom Transformer

  • Reference Conformer und atom-level Darstellung

    • Um die atom-level Single-Darstellung q zu erstellen, wird für jede Aminosäure, jedes Nukleotid und jeden Liganden ein Reference Conformer berechnet
    • Ein Conformer ist eine 3D-Anordnung der Atome eines Moleküls, die durch Sampling von Rotationen um Einfachbindungen erzeugt wird
    • Standard-Aminosäuren verwenden einen energiearmen Conformer, der per Lookup verfügbar ist; für kleine Moleküle wird mit RDKit’s ETKDGv3 ein 3D-Conformer erzeugt
    • Durch Kombination der relativen Positionen, Atomladungen, Ordnungszahlen, Identifikatoren usw. des Conformers entsteht die atom-level Single Representation c
    • Mit c wird die atom-level Pair Representation p initialisiert; eine Maske v sorgt dafür, dass nur die im Reference Conformer berechneten Abstände zwischen Atomen enthalten sind
    • q startet als Kopie von c und wird anschließend im Atom Transformer aktualisiert
  • Rolle des Atom Transformers

    • Der Atom Transformer ist ein Modul, das Attention auf Atomebene ausführt und q mithilfe von p und der ursprünglichen Darstellung c aktualisiert
    • c wird nicht aktualisiert, sondern wie eine Residual Connection zur Ausgangsdarstellung verwendet
    • Die Grundstruktur ähnelt einem Transformer und enthält LayerNorm, Attention und MLP Transition, die einzelnen Schritte werden jedoch durch zusätzliche Eingaben aus c und p angepasst
  • Adaptive LayerNorm

    • Adaptive LayerNorm erzeugt gamma und beta aus einer Hilfseingabe, statt feste gamma- und beta-Werte zu lernen
    • Im Atom Transformer ist q das Ziel der Reskalierung, und die Reskalierungsparameter werden aus der Hilfseingabe c vorhergesagt
  • Attention with Pair Bias

    • Atom-level Attention with Pair Bias ist eine Erweiterung von Self-Attention
    • Query, Key und Value stammen alle aus der Single Representation q, aber nach dem Query-Key-Dot-Product wird eine lineare Projection der Pair Representation p als Bias addiert
    • Informationen fließen von der Pair Representation nach q, aber in diesem Schritt wird p nicht mit Informationen aus q aktualisiert
    • Ein Gate, das durch Anwenden einer Sigmoid-Funktion auf eine zusätzliche Projection entsteht, wird mit dem Attention-Ergebnis multipliziert und steuert, welche Informationen im Residual Stream verbleiben
    • Da die Zahl der Atome deutlich größer sein kann als die Zahl der Tokens, wird statt Full Attention Sequence-local atom attention verwendet
    • Eine Local Group aus 32 Atomen kann auf 128 andere Atome attendieren
  • Conditioned Gating und Transition

    • Conditioned Gating wendet ein Gate auf die Daten an, das aus der ursprünglichen atom-level Single Matrix c erzeugt wurde
    • Conditioned Transition entspricht dem MLP eines Transformers und heißt conditioned, weil Adaptive LayerNorm und Conditional Gating von c abhängen
    • AF3 verwendet im Transition Block SwiGLU statt ReLU
    • Die ReLU-basierte Transition in AF2 hat die Struktur 4-fache Up-Projection, ReLU und Down-Projection
    • SwiGLU in AF3 wendet eine Swish-Nichtlinearität auf eine von zwei Up-Projections an, multipliziert das Ergebnis und führt anschließend eine Down-Projection aus

Atomdarstellungen zu Token-Darstellungen aggregieren

  • Da die Phase des Representation Learning anschließend auf token-level arbeitet, werden die atom-level Darstellungen zu token-level Darstellungen aggregiert
  • Die atom-level Representation wird auf eine größere Dimension projiziert, anschließend wird der Mittelwert der Atome gebildet, die zum selben Token gehören
  • Diese Mittelwert-Aggregation wird angewendet, wenn wie bei Standard-Aminosäuren und Nukleotiden mehrere Atome mit einem Token verbunden sind; Eingaben mit 1 Token pro Atom bleiben unverändert
  • Mit der token-level Single-Eingabe werden außerdem Statistiken aus dem MSA kombiniert
    • Aminosäuretyp
    • MSA-Aminosäureverteilung an dieser Position
    • Deletion Mean des jeweiligen Tokens
  • Bei Tokens ohne MSA, etwa Ligandenatomen, sind diese Werte 0
  • Die so erstellten s_inputs werden durch eine Projection zu s_init und in der Phase des Representation Learning aktualisiert
  • Die Pair Representation z_init ist ein dreidimensionaler Tensor, der Beziehungen für jedes Token-Paar speichert; jedes z_i,j ist ein Vektor mit c_z=128 Dimensionen
  • Zur Initialisierung von z_i,j werden Projections von s_i und s_j, Relative Positional Encoding sowie vom Nutzer angegebene Bond-Informationen zwischen Tokens addiert

Representation Learning: Template, MSA, Pairformer

  • Representation Learning ist der Trunk, der den Großteil der Modellberechnung ausmacht; Ziel ist es, die token-level Single-Darstellung s und die Pair-Darstellung z zu verbessern
  • Single Sequence Representation meint nicht nur eine einzelne Proteinsequenz, sondern die aneinandergereihte Sequenz aller Atome oder Tokens in der Struktur
  • Template Module

    • Jedes Template durchläuft eine lineare Projection und wird mit einer linearen Projection der Pair Representation z addiert
    • Die kombinierte Matrix durchläuft den Pairformer Stack
    • Die Ergebnisse mehrerer Templates werden gemittelt und anschließend erneut durch eine lineare Layer geführt
    • In der letzten linearen Layer wird ReLU verwendet; dies ist eine der seltenen Stellen in AF3, an denen ReLU als Nichtlinearität eingesetzt wird
  • MSA Module

    • Das MSA Module ist dem Evoformer aus AF2 sehr ähnlich und verbessert die MSA Representation m und die Pair Representation z gleichzeitig
    • Es werden nicht alle MSA Rows vollständig verwendet, sondern per Subsampling ausgewählt; anschließend wird eine Projection der Single Representation zum MSA addiert
    • Outer Product Mean ist eine Operation, die MSA-Informationen in die Pair Representation einbringt
      • Für jeden Token-Index i,j wird für alle Evolutionary Sequences das Outer Product von m_s,i und m_s,j berechnet
      • Dieses wird über die gesamte Sequenz gemittelt, geflattet, projiziert und zu z_i,j addiert
      • Dies ist die einzige Stelle im Modell, an der Informationen zwischen Evolutionary Sequences geteilt werden
    • Row-wise gated self-attention using only pair bias aktualisiert das MSA mithilfe der Pair Representation
      • Statt mit Query und Key einen Attention Score zu erzeugen, wird die Pair Representation z in eine Matrix projiziert und als Attention Score zwischen Tokens verwendet
      • Da dies unabhängig auf jede MSA Row angewendet wird, werden in diesem Schritt keine Informationen zwischen Evolutionary Sequences geteilt
    • Am Ende des MSA Module wird die Pair Representation erneut durch Triangle Update und Triangle Attention aktualisiert

Pairformer und Triangle-Operationen

  • Nachdem z mit Template und MSA aktualisiert wurde, werden Template und MSA nicht weiter verwendet; nur s und z werden in den Pairformer eingespeist
  • Der Pairformer erzeugt durch die Wiederholung von 48 Blöcken das finale s_trunk und z_trunk
  • Intuition der Triangle-Operationen

    • Triangle Update und Triangle Attention sind Strukturen, die versuchen, die Intuition der Dreiecksungleichung im Modell abzubilden
    • z_i,j des Pair-Tensors ist zwar nicht der physikalische Abstand selbst, enthält aber die Beziehung zwischen den Tokens i und j; daher werden die drei Beziehungen i-j, j-k und i-k so aktualisiert, dass sie zueinander konsistent sind
    • Die Dreiecksungleichung wird im Modell nicht direkt erzwungen, sondern dadurch induziert, dass für alle Triplets (i,j,k) z_i,j aktualisiert wird
    • z lässt sich wie eine directed adjacency matrix betrachten, daher werden die Richtungen outgoing edge und incoming edge getrennt verarbeitet
  • Triangle Updates

    • Beim outgoing update wird jedes z_i,j mithilfe eines anderen Elements derselben Zeile, z_i,k, und der dritten Edge z_j,k aktualisiert
    • In der Implementierung werden drei Projektionen a, b, g von z erzeugt; die elementweise Multiplikation von Zeile i und Zeile j wird über k aufsummiert, anschließend wird das Gate g angewendet
    • Das incoming update ist die Variante mit vertauschten Zeilen und Spalten: z_i,j wird über ein anderes Element derselben Spalte, z_k,j, und z_k,i aktualisiert
  • Triangle Attention

    • Triangle Attention ist eine Form, die das Triangle-Prinzip zur axial attention hinzufügt, bei der unabhängige Attention auf Zeilen und Spalten einer 2D-Matrix angewendet wird
    • Im Fall „starting node“ wird z_j,k als Bias zur Query-Key-Übereinstimmung von z_i,j und z_i,k addiert
    • Im Fall „ending node“ arbeitet sie spaltenbasiert und biasiert den Attention-Score von z_i,j und z_k,i mit z_k,j
  • Single Attention with Pair Bias

    • Nach dem Triangle-Schritt und dem Transition-Block wird die Single Representation s durch single attention with pair bias aktualisiert, die die aktualisierte Pair Representation z verwendet
    • Da sie auf Token-Ebene arbeitet, wird nicht die auf Atom-Ebene verwendete blockweise Sparse Attention genutzt, sondern Full Attention

Strukturvorhersage: Denoising von Atomkoordinaten per Diffusion

  • Grundprinzip des Diffusion-Modells

    • AF3 führt die finale Strukturvorhersage als atom-level diffusion durch
    • Ein Diffusion Model fügt realen Daten schrittweise Random Noise hinzu und wird darauf trainiert, vorherzusagen, welches Noise hinzugefügt wurde
    • Bei der Inferenz beginnt es mit vollständigem Random Noise und erzeugt einen denoised datapoint, indem es in jedem Schritt das vom Modell vorhergesagte Noise entfernt
    • Conditional Diffusion nimmt die aktuelle noisy generation, die aktuelle Timestep-Repräsentation und einen Conditioning-Vektor als Eingabe und erzeugt ein Ergebnis, das zur Bedingung passt
    • Das Denoising-Ziel in AF3 ist die Matrix x, die die x,y,z-Koordinaten aller Atome enthält
  • Rotations- und Translations-Augmentierung statt IPA aus AF2

    • AF3 verwendet nicht AF2s Invariant Point Attention, sondern rotiert und verschiebt bei jedem Timestep den gesamten gerade vorhergesagten Komplex zufällig
    • Diese Augmentierung lässt das Modell lernen, dass jede Rotation und Translation als dieselbe Struktur gültig ist, und ist ein einfacherer Ansatz als AF2s IPA
    • Die Rotation wird um den Mittelwert aller Atomkoordinaten der aktuellen Generation angewendet, und die Translation wird in jeder Dimension aus einer N(0,1)-Gauss-Verteilung gesampelt
    • Den Koordinaten wird außerdem geringes Noise hinzugefügt, um vielfältigere Generations zu fördern
    • Bei der Inferenz können mehrere Generations mit einem Confidence Head bewertet und die Generation mit dem höchsten Score zurückgegeben werden
  • Die vier Phasen des Diffusion Module

    • Jeder Denoising-Schritt verwendet mehrere Conditioning Representations
      • Trunk-Ausgaben s_trunk, z_trunk
      • Initiale Repräsentationen s_inputs, c_inputs aus dem Input Embedder
    • Der Diffusion-Prozess wechselt zwischen Token- und Atom-Raum und besteht aus vier Phasen
        1. Token-Level-Conditioning-Tensor vorbereiten
        1. Atom-Level-Conditioning-Tensor vorbereiten, Atom Transformer anwenden, auf Token-Level aggregieren
        1. Token-Level-Attention anwenden
        1. Mit Atom-Level-Attention das atomweise Noise Update vorhersagen
    • Beim Token-Level-Conditioning wird z_trunk mit Relative Positional Encoding kombiniert und durch einen Transition-Block geführt
    • Zur Single Representation werden s_inputs und s_trunk kombiniert; anschließend wird ein Fourier Embedding entsprechend dem Diffusion Timestep hinzugefügt
    • In der Atom-Level-Phase werden die initialen c und p mit der aktuellen Token-Level Representation aktualisiert, und die aktuellen Koordinaten x werden mit der Data Variance skaliert, um dimensionlose Koordinaten r zu erzeugen
    • In der letzten Atom-Level-Phase mappt ein Linear Layer q auf R^3 und erzeugt so das Coordinate Update r_update für alle Atome
    • Das Update wird unter Berücksichtigung von Data Variance und Noise Schedule wieder zu x_update reskaliert und anschließend auf die aktuellen Koordinaten x_l angewendet

Verlustfunktion und Confidence Head

  • Der gesamte Loss ist eine gewichtete Summe aus drei Termen

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogram bewertet auf Token-Ebene die Genauigkeit des vorhergesagten Distogramms
    • Beim Erzeugen von Token-Koordinaten aus Atomkoordinaten werden die Koordinaten des Center-Atoms jedes Tokens verwendet
    • Die Distogramm-Distanz wird als categorical value behandelt; vorhergesagtes und tatsächliches Distogramm werden per Cross Entropy verglichen
  • L_diffusion

    • L_diffusion ist eine gewichtete Summe mehrerer Terme für Atompositionen
    • L_MSE berechnet den Mean Squared Error zwischen Positionen nicht nur für Center-Atoms, sondern für alle Atome; DNA-, RNA- und Liganden-Atome werden upgewichtet
    • L_bond ist ein zusätzlicher MSE-Term, der die Genauigkeit der Bond Length von Atom-Paaren in Protein-Ligand-Bonds erhöhen soll
    • In der frühen Trainingsphase gilt α_bond=0, der Term wird also erst später eingeführt
    • L_smooth_LDDT ist ein Loss, der die Local Distance Accuracy glatt und differenzierbar macht
      • Als Thresholds werden vier Werte verwendet: 4 Å, 2 Å, 1 Å und 0,5 Å
      • Nukleotid-Atompaare werden ignoriert, wenn sie weiter als 30 Å voneinander entfernt sind
      • Protein- oder Liganden-Atompaare werden ignoriert, wenn sie weiter als 15 Å voneinander entfernt sind
  • L_confidence

    • L_confidence erhöht nicht direkt die Strukturgenauigkeit, sondern trainiert das Modell darauf, die Genauigkeit seiner eigenen Vorhersage zu schätzen
    • Er besteht aus Losses, die vier Confidence-Metriken entsprechen
      • pLDDT: Local Distance Accuracy für nahe Atome
      • PAE: Predicted Alignment Error eines Token-Paars
      • PDE: Predicted Distance Error zwischen Token-Paaren
      • experimentally resolved prediction: Vorhersage, ob jedes Atom in der experimentellen Struktur resolved wurde
    • Selbst wenn eine vorhergesagte Struktur ungenau ist und der PAE hoch ausfällt, kann der entsprechende PAE-Loss niedrig sein, wenn das Modell den PAE ebenfalls hoch vorhersagt
    • Die Confidence Prediction wird in einem Zwischenschritt der Diffusion erzeugt
    • Der Gradient des Confidence Loss aktualisiert nur den Confidence-Prediction-Head und wirkt sich nicht auf den Rest des Modells aus

Zusätzliche Lernverfahren und Effizienzsteigerungen

  • Recycling

    • AF3 verwendet wie AF2 Weight Recycling
    • Statt das Modell tiefer zu machen, werden dieselben Gewichte mehrfach wiederverwendet, um die Representation schrittweise zu verbessern
    • Auch die Diffusion nutzt bei der Inference Timestep-Informationen und verwendet dieselben Gewichte bei jedem Timestep erneut, wodurch Recycling inhärent enthalten ist
  • Cross-Distillation

    • AF3 nutzt nicht nur synthetische Trainingsdaten, die es selbst erzeugt hat, sondern auch synthetische Daten von AF2 und AF-Multimer
    • Nach dem Wechsel zu diffusion-basierter Generation gab es das Problem, dass die „Spaghetti“-Form verschwand, mit der AF2 niedrig vertrauenswürdige und ungeordnete Bereiche visuell unterscheidbar machte
    • Durch Einbeziehen von AF2- und AF-Multimer-Generationen in die AF3-Trainingsdaten lernt AF3, in Bereichen, bei denen AF2 unsicher war, unfolded regions auszugeben
    • Im Distillation-Dataset werden Nukleinsäuren und kleine Moleküle entfernt, die AF2 und AF-Multimer nicht verarbeiten können
    • Nachdem das vorherige Modell eine vorhergesagte Struktur erzeugt und diese mit dem Original aligned wurde, werden die entfernten Moleküle wieder hinzugefügt
    • Wenn ein wieder hinzugefügtes Molekül einen Atom-Clash erzeugt, wird die gesamte Struktur ausgeschlossen, um zu vermeiden, dass das Modell lernt, Clashes zuzulassen
  • Cropping und Trainingsphase

    • Das Modell selbst hat keine explizite Beschränkung der Eingabesequenzlänge, aber mehrere Operationen wachsen mit N_tokens^3, wodurch Speicher- und Compute-Anforderungen steigen
    • Zur Effizienzsteigerung werden Proteine per Random Crop beschnitten
    • Da Interaktionen zwischen mehreren Chains modelliert werden müssen, muss der Crop die Chains gemeinsam umfassen
    • Es werden drei Cropping-Methoden verwendet
      • contiguous cropping: Auswahl einer zusammenhängenden Aminosäuresequenz aus jeder Chain
      • spatial cropping: Auswahl von Aminosäuren basierend auf der Distanz zu einem Referenzatom
      • spatial interface cropping: Auswahl basierend auf der Distanz zu Atomen am Binding Interface
    • Auch ein mit Random Crop 384 trainiertes Modell kann auf längere Sequenzen angewendet werden; um die Fähigkeit zur Verarbeitung längerer Sequenzen zu verbessern, wird jedoch wiederholt mit größerer Sequence Length fine-getuned
  • Clashing und Batch Size

    • Der AF3-Loss enthält keine Clash Penalty für overlapping atoms
    • Das diffusion-based Structure Module könnte theoretisch zwei Atome an derselben Position vorhersagen, nach dem Training ist dieses Problem jedoch gering
    • Beim Ranking generierter Strukturen wird eine Clashing Penalty verwendet
    • Der Diffusion Process wirkt zwar komplex, hat aber geringere Rechenkosten als der Trunk
    • Zur Effizienzsteigerung beim Training wird die Batch Size nach dem Trunk erweitert
    • Jede Input Structure durchläuft Embedding und Trunk einmal; anschließend werden 48 unabhängige data-augmented Structures parallel trainiert

AF3-Design aus ML-Perspektive

  • Eine Struktur ähnlich wie Retrieval-Augmented Generation

    • Die MSA- und Template-Suche von AF3 hat einen ähnlichen Charakter wie RAG bei Sprachmodellen
    • Im AlphaFold-Bereich wurde die Nutzung struktureller Templates schon lange vor dem Begriff RAG als Homology Modeling bezeichnet
    • AF3 hat den Anteil der MSA-Verarbeitung gegenüber AF2 reduziert, enthält aber weiterhin MSA und Templates
    • Einige Proteinvorhersagemodelle wie ESMFold entfernen Retrieval und verwenden fully parametric inference
  • Pair-Bias Attention

    • Pair-Bias Attention, ein zentraler Bestandteil von AF2, wird in AF3 breiter eingesetzt
    • Query, Key und Value stammen aus derselben Quelle, aber zur Attention Map wird ein Bias-Term aus einer anderen Quelle hinzugefügt
    • Das ist eine leichtgewichtigere Form des Informationsaustauschs als vollständige Cross-Attention
    • Da die Pair Representation der Attention Map von Natur aus ähnelt, könnte diese Struktur gut zur Proteinmodellierung passen
  • Reduzierung des Self-supervised Training

    • Modelle der ESM-Familie zeigten Stärken bei einem Ansatz, MSA-Embeddings durch Self-supervised Pre-training zu ersetzen
    • In AF2 gab es eine zusätzliche Task zur Vorhersage maskierter Tokens in der MSA, in AF3 wurde sie jedoch entfernt
    • AF3 hat den Compute-Aufwand für die MSA-Verarbeitung reduziert und verwendet kein Self-supervised Language-Modeling-Pre-training für MSA
    • Mögliche Gründe sind, dass massives Pre-training beim Compute-Einsatz ineffizient war, dass ein kleines MSA-Modul besser war als pre-trained Embeddings oder dass die Kombination aus einer hybriden Atom-Token-Struktur mit gemischten Aminosäuren, DNA/RNA und Liganden nicht gut zu pre-trained Embeddings passte
  • Mischung aus Classification und Regression

    • AF3 verwendet wie AF2 sowohl MSE als auch binned Classification Loss
    • Ein Merkmal von Classification Loss ist, dass es keinen Credit gibt, wenn man nur um einen Distogram-Bin danebenliegt – genauso wie bei einem weit entfernten Fehler
    • Die Begründung für diese Designentscheidung ist nicht klar, möglicherweise waren die Gradients aber stabiler als bei mehreren MSE-Losses
  • Elemente, die einer recurrent architecture ähneln

    • AF3 enthält viele Elemente, die eher an ein recurrent network erinnern als an einen gewöhnlichen Transformer
    • Gating steuert den Informationsfluss im Residual Stream und ähnelt den Gates von LSTM oder GRU
    • Recycling und Diffusion wenden dieselben Weights wiederholt an, um Vorhersagen schrittweise zu verbessern
    • Ähnlich wie bei adaptive compute time hängen iterative Updates mit Strukturen zusammen, die bei schwierigen Inputs mehr Verarbeitung anwenden können
    • AF2-Ablations zeigten die Bedeutung von Recycling, über die Bedeutung von Gating wurde jedoch nicht viel diskutiert

Noch keine Kommentare.

Noch keine Kommentare.