KL Verlust mit einer Einheit Gauß


10

Ich habe eine VAE implementiert und zwei verschiedene Online-Implementierungen der vereinfachten univariaten Gaußschen KL-Divergenz festgestellt. Die ursprüngliche Abweichung gemäß hier ist Wenn wir annehmen, dass unser Prior eine Einheit Gauß'sche ist, dh und , vereinfacht sich dies bis hinunter zu Und hier liegt meine Verwirrung. Obwohl ich mit der obigen Implementierung einige obskure Github-Repos gefunden habe, wird sie häufiger verwendet:

KLloss=log(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1
KLloss=log(σ1)+σ12+μ12212
KLloss=12(2log(σ1)σ12μ12+1)

=12(log(σ1)σ1μ12+1)
Zum Beispiel im offiziellen Keras-Autoencoder-Tutorial . Meine Frage ist dann, was fehlt mir zwischen diesen beiden? Der Hauptunterschied besteht darin, den Faktor 2 auf den logarithmischen Term zu setzen und die Varianz nicht zu quadrieren. Analytisch habe ich letzteres mit Erfolg eingesetzt, für was es wert ist. Vielen Dank im Voraus für jede Hilfe!

Antworten:


7

Beachten Sie, dass Sie durch Ersetzen von durch in der letzten Gleichung die vorherige wiederherstellen (dh ). Dies lässt mich denken, dass im ersten Fall der Encoder zur Vorhersage der Varianz verwendet wird, während er im zweiten Fall zur Vorhersage der Standardabweichung verwendet wird.σ1σ12log(σ1)σ12log(σ1)σ12

Beide Formulierungen sind äquivalent und das Ziel bleibt unverändert.


Ich denke nicht, dass dies der Fall sein kann. Ja, beide werden minimiert, wenn für Null und Einheit . In der ursprünglichen Gleichung (mit der Varianz) ist die Strafe für das Entfernen von von der Einheit jedoch weitaus größer als in der zweiten Gleichung (basierend auf der Standardabweichung). Die Strafe für Variationen in ist für beide gleich, und der Rekonstruktionsfehler wäre der gleiche, so dass die Verwendung der zweiten Version die relative Bedeutung von Abweichungen von von der Einheit dramatisch ändert . Was vermisse ich? μσσμσ
TheBamf

0

Ich glaube, die Antwort ist einfacher. In VAE verwenden Menschen normalerweise eine multivariate Normalverteilung, die eine Kovarianzmatrix anstelle von Varianz . Das sieht in einem Code verwirrend aus, hat aber die gewünschte Form.Σσ2

Hier finden Sie die Ableitung einer KL-Divergenz für multivariate Normalverteilungen: Ableitung des KL-Divergenzverlusts für VAEs

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.