Hinweis: Den Code hinter dieser Antwort finden Sie hier .
Angenommen, wir haben einige Daten aus zwei verschiedenen Gruppen, rot und blau:
Hier können wir sehen, welcher Datenpunkt zur roten oder blauen Gruppe gehört. Dies macht es einfach, die Parameter zu finden, die jede Gruppe charakterisieren. Zum Beispiel liegt der Mittelwert der roten Gruppe bei 3, der Mittelwert der blauen Gruppe bei 7 (und wir könnten die genauen Mittelwerte finden, wenn wir wollten).
Dies wird allgemein als Maximum-Likelihood-Schätzung bezeichnet . Bei einigen Daten berechnen wir den Wert eines Parameters (oder von Parametern), der diese Daten am besten erklärt.
Stellen Sie sich nun vor, wir können nicht sehen, welcher Wert aus welcher Gruppe entnommen wurde. Für uns sieht alles lila aus:
Hier haben wir das Wissen, dass es zwei Gruppen von Werten gibt, aber wir wissen nicht, zu welcher Gruppe ein bestimmter Wert gehört.
Können wir noch die Mittelwerte für die rote und die blaue Gruppe abschätzen, die am besten zu diesen Daten passen?
Ja, oft können wir! Die Erwartungsmaximierung gibt uns eine Möglichkeit, dies zu tun. Die sehr allgemeine Idee hinter dem Algorithmus ist folgende:
- Beginnen Sie mit einer ersten Schätzung der einzelnen Parameter.
- Berechnen Sie die Wahrscheinlichkeit, dass jeder Parameter den Datenpunkt erzeugt.
- Berechnen Sie die Gewichte für jeden Datenpunkt und geben Sie an, ob er roter oder blauer ist, basierend auf der Wahrscheinlichkeit, dass er von einem Parameter erzeugt wird. Kombinieren Sie die Gewichte mit den Daten ( Erwartung ).
- Berechnen Sie anhand der gewichtsangepassten Daten eine bessere Schätzung für die Parameter ( Maximierung ).
- Wiederholen Sie die Schritte 2 bis 4, bis die Parameterschätzung konvergiert (der Prozess erzeugt keine andere Schätzung mehr).
Diese Schritte bedürfen einer weiteren Erläuterung, daher gehe ich auf das oben beschriebene Problem ein.
Beispiel: Schätzung von Mittelwert und Standardabweichung
In diesem Beispiel werde ich Python verwenden, aber der Code sollte ziemlich leicht zu verstehen sein, wenn Sie mit dieser Sprache nicht vertraut sind.
Angenommen, wir haben zwei Gruppen, rot und blau, wobei die Werte wie im obigen Bild verteilt sind. Insbesondere enthält jede Gruppe einen Wert aus einer Normalverteilung mit den folgenden Parametern:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
Hier ist noch einmal ein Bild dieser roten und blauen Gruppen (damit Sie nicht nach oben scrollen müssen):
Wenn wir die Farbe jedes Punktes sehen können (dh zu welcher Gruppe er gehört), ist es sehr einfach, den Mittelwert und die Standardabweichung für jede Gruppe zu schätzen. Wir übergeben nur die roten und blauen Werte an die in NumPy integrierten Funktionen. Beispielsweise:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Aber was ist, wenn wir die Farben der Punkte nicht sehen können ? Das heißt, anstelle von Rot oder Blau wurde jeder Punkt lila gefärbt.
Um zu versuchen, die Mittelwert- und Standardabweichungsparameter für die roten und blauen Gruppen wiederherzustellen, können wir die Erwartungsmaximierung verwenden.
Unser erster Schritt ( Schritt 1 oben) besteht darin, die Parameterwerte für den Mittelwert und die Standardabweichung jeder Gruppe zu erraten. Wir müssen nicht intelligent raten; Wir können beliebige Zahlen auswählen:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Diese Parameterschätzungen erzeugen Glockenkurven, die folgendermaßen aussehen:
Das sind schlechte Schätzungen. Beide Mittel (die vertikalen gepunkteten Linien) sehen zum Beispiel für sinnvolle Punktgruppen weit entfernt von jeder Art von "Mitte" aus. Wir wollen diese Schätzungen verbessern.
Der nächste Schritt ( Schritt 2 ) besteht darin, die Wahrscheinlichkeit zu berechnen, mit der jeder Datenpunkt unter den aktuellen Parameterschätzungen erscheint:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Hier haben wir einfach jeden Datenpunkt in die Wahrscheinlichkeitsdichtefunktion für eine Normalverteilung eingefügt, wobei wir unsere aktuellen Schätzungen zum Mittelwert und zur Standardabweichung für Rot und Blau verwenden. Dies sagt uns zum Beispiel, dass nach unseren derzeitigen Schätzungen der Datenpunkt bei 1,761 viel wahrscheinlicher rot (0,189) als blau (0,00003) ist.
Für jeden Datenpunkt können wir diese beiden Wahrscheinlichkeitswerte in Gewichte umwandeln ( Schritt 3 ), sodass sie wie folgt zu 1 summieren:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Mit unseren aktuellen Schätzungen und unseren neu berechneten Gewichten können wir jetzt neue Schätzungen für den Mittelwert und die Standardabweichung der roten und blauen Gruppe berechnen ( Schritt 4 ).
Wir berechnen zweimal den Mittelwert und die Standardabweichung unter Verwendung aller Datenpunkte, jedoch mit unterschiedlichen Gewichtungen: einmal für die roten Gewichte und einmal für die blauen Gewichte.
Das Schlüsselelement der Intuition ist, dass je größer das Gewicht einer Farbe auf einem Datenpunkt ist, desto stärker beeinflusst der Datenpunkt die nächsten Schätzungen für die Parameter dieser Farbe. Dies hat den Effekt, dass die Parameter in die richtige Richtung "gezogen" werden.
def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.
This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Wir haben neue Schätzungen für die Parameter. Um sie wieder zu verbessern, können wir zu Schritt 2 zurückkehren und den Vorgang wiederholen. Wir tun dies, bis die Schätzungen konvergieren oder nachdem eine bestimmte Anzahl von Iterationen durchgeführt wurde ( Schritt 5 ).
Für unsere Daten sehen die ersten fünf Iterationen dieses Prozesses folgendermaßen aus (neuere Iterationen sehen stärker aus):
Wir sehen, dass die Mittelwerte bereits bei einigen Werten konvergieren und auch die Formen der Kurven (bestimmt durch die Standardabweichung) stabiler werden.
Wenn wir 20 Iterationen fortsetzen, erhalten wir Folgendes:
Der EM-Prozess hat sich den folgenden Werten angenähert, die den tatsächlichen Werten sehr nahe kommen (wo wir die Farben sehen können - keine versteckten Variablen):
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
Im obigen Code haben Sie möglicherweise bemerkt, dass die neue Schätzung für die Standardabweichung unter Verwendung der Schätzung der vorherigen Iteration für den Mittelwert berechnet wurde. Letztendlich spielt es keine Rolle, ob wir zuerst einen neuen Wert für den Mittelwert berechnen, da wir nur die (gewichtete) Varianz der Werte um einen zentralen Punkt herum finden. Die Schätzungen für die Parameter werden weiterhin konvergieren.