Numerisches Beispiel zum Verständnis von Expectation-Maximization


117

Ich versuche, einen guten Überblick über den EM-Algorithmus zu bekommen, um ihn implementieren und verwenden zu können. Ich verbrachte einen ganzen Tag damit, die Theorie und ein Papier zu lesen, in dem EM verwendet wird, um ein Flugzeug unter Verwendung der Positionsinformationen, die von einem Radar kommen, zu verfolgen. Ehrlich gesagt glaube ich nicht, dass ich die zugrunde liegende Idee vollständig verstehe. Kann mich jemand auf ein numerisches Beispiel hinweisen, das einige Iterationen (3-4) der EM für ein einfacheres Problem zeigt (wie das Schätzen der Parameter einer Gaußschen Verteilung oder einer Folge einer Sinusreihe oder das Anpassen einer Linie)?

Selbst wenn mich jemand auf ein Stück Code (mit synthetischen Daten) hinweisen kann, kann ich versuchen, den Code schrittweise durchzugehen.


1
k-means ist sehr em, aber mit konstanter Varianz und ist relativ einfach.
EngrStudent

2
@ arjsgh21 kannst du bitte das erwähnte papier über das flugzeug posten? Hört sich sehr interessant an. Vielen Dank
Wakan Tanka

1
Es gibt ein Online-Tutorial, das behauptet, ein sehr klares mathematisches Verständnis des Em-Algorithmus "EM Demystified: Ein Tutorial zur Expectation-Maximization" zu liefern. Das Beispiel ist jedoch so schlecht, dass es das Unverständliche einschränkt.
Shamisen Expert

Antworten:


98

Dies ist ein Rezept zum Erlernen von EM anhand eines praktischen und (meiner Meinung nach) sehr intuitiven "Coin-Toss" -Beispiels:

  1. Lesen Sie dieses kurze EM-Tutorial von Do und Batzoglou. Dies ist das Schema, in dem das Beispiel für den Münzwurf erklärt wird:

    Bildbeschreibung hier eingeben

  2. Möglicherweise haben Sie Fragezeichen im Kopf, insbesondere, woher die Wahrscheinlichkeiten im Schritt "Erwartung" stammen. Bitte beachten Sie die Erläuterungen auf dieser Seite zum Austausch von Mathematikstapeln .

  3. Schauen Sie sich diesen Code an, den ich in Python geschrieben habe und der die Lösung des Münzwurfproblems im EM-Tutorialpapier von Punkt 1 simuliert:

    import numpy as np
    import math
    import matplotlib.pyplot as plt
    
    ## E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ##
    
    def get_binomial_log_likelihood(obs,probs):
        """ Return the (log)likelihood of obs, given the probs"""
        # Binomial Distribution Log PDF
        # ln (pdf)      = Binomial Coeff * product of probabilities
        # ln[f(x|n, p)] =   comb(N,k)    * num_heads*ln(pH) + (N-num_heads) * ln(1-pH)
    
        N = sum(obs);#number of trials  
        k = obs[0] # number of heads
        binomial_coeff = math.factorial(N) / (math.factorial(N-k) * math.factorial(k))
        prod_probs = obs[0]*math.log(probs[0]) + obs[1]*math.log(1-probs[0])
        log_lik = binomial_coeff + prod_probs
    
        return log_lik
    
    # 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
    # 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
    # 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
    # 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
    # 5th:  Coin A, {THHHTHHHTH}, 7H,3T
    # so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
    
    # represent the experiments
    head_counts = np.array([5,9,8,4,7])
    tail_counts = 10-head_counts
    experiments = zip(head_counts,tail_counts)
    
    # initialise the pA(heads) and pB(heads)
    pA_heads = np.zeros(100); pA_heads[0] = 0.60
    pB_heads = np.zeros(100); pB_heads[0] = 0.50
    
    # E-M begins!
    delta = 0.001  
    j = 0 # iteration counter
    improvement = float('inf')
    while (improvement>delta):
        expectation_A = np.zeros((len(experiments),2), dtype=float) 
        expectation_B = np.zeros((len(experiments),2), dtype=float)
        for i in range(0,len(experiments)):
            e = experiments[i] # i'th experiment
              # loglikelihood of e given coin A:
            ll_A = get_binomial_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) 
              # loglikelihood of e given coin B
            ll_B = get_binomial_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) 
    
              # corresponding weight of A proportional to likelihood of A 
            weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) 
    
              # corresponding weight of B proportional to likelihood of B
            weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) 
    
            expectation_A[i] = np.dot(weightA, e) 
            expectation_B[i] = np.dot(weightB, e)
    
        pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
        pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 
    
        improvement = ( max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - 
                        np.array([pA_heads[j],pB_heads[j]]) )) )
        j = j+1
    
    plt.figure();
    plt.plot(range(0,j),pA_heads[0:j], 'r--')
    plt.plot(range(0,j),pB_heads[0:j])
    plt.show()
    

2
@ Zhubarb: Können Sie bitte die Bedingung für die Beendigung der Schleife erläutern (dh um festzustellen, wann der Algorithmus konvergiert)? Was berechnet die Variable "Verbesserung"?
Stackoverflowuser2010

@ stackoverflowuser2010, Verbesserung betrifft zwei Deltas: 1) den Wechsel zwischen pA_heads[j+1]und pA_heads[j]und 2) den Wechsel zwischen pB_heads[j+1]und pB_heads[j]. Und es dauert das Maximum der beiden Änderungen. Zum Beispiel, wenn Delta_A=0.001und Delta_B=0.02, wird die Verbesserung von Schritt jzu Schritt j+1sein 0.02.
Zhubarb

1
@Zhubarb: Ist das ein Standardansatz für die Konvergenzberechnung in EM, oder haben Sie sich das ausgedacht? Wenn es sich um einen Standardansatz handelt, können Sie bitte eine Referenz angeben?
Stackoverflowuser2010

Hier ist ein Hinweis auf die Konvergenz von EM. Ich habe den Code vor einiger Zeit geschrieben, kann mich also nicht gut erinnern. Ich glaube, was Sie im Code sehen, ist mein Konvergenzkriterium für diesen speziellen Fall. Die Idee ist, Iterationen zu stoppen, wenn das Maximum der Verbesserungen für A und B kleiner als ist delta.
Zhubarb

1
Hervorragend, es gibt nichts
Schöneres

63

Es hört sich so an, als ob Ihre Frage zwei Teile hat: die zugrunde liegende Idee und ein konkretes Beispiel. Ich werde mit der zugrunde liegenden Idee beginnen und dann unten auf ein Beispiel verweisen.


EM ist in Catch-22-Situationen nützlich, in denen es den Anschein hat, als müssten Sie kennen, bevor Sie berechnen können, und Sie müssen kennen, bevor Sie berechnen können .B B AABBA

Der häufigste Fall, mit dem Menschen zu tun haben, sind wahrscheinlich Mischungsverteilungen. Schauen wir uns für unser Beispiel ein einfaches Gaußsches Mischungsmodell an:

Sie haben zwei verschiedene univariate Gaußsche Verteilungen mit unterschiedlichen Mitteln und Einheitenvarianz.

Sie verfügen über eine Reihe von Datenpunkten, sind sich jedoch nicht sicher, welche Punkte von welcher Verteilung stammen, und Sie sind sich auch nicht sicher, welche Mittel die beiden Verteilungen haben.

Und jetzt steckst du fest:

  • Wenn Sie die wahren Mittel kennen, können Sie herausfinden, welche Datenpunkte von welchem ​​Gaußschen stammen. Wenn beispielsweise ein Datenpunkt einen sehr hohen Wert hatte, stammte er wahrscheinlich aus der Verteilung mit dem höheren Mittelwert. Aber Sie wissen nicht, was die Mittel sind, also funktioniert das nicht.

  • Wenn Sie wissen, von welcher Verteilung jeder Punkt stammt, können Sie die Mittelwerte der beiden Verteilungen anhand der Stichprobenmittelwerte der relevanten Punkte schätzen. Sie wissen jedoch nicht genau, welche Punkte Sie welcher Verteilung zuweisen sollen, daher funktioniert dies auch nicht.

Daher scheint keiner der beiden Ansätze zu funktionieren: Sie müssen die Antwort kennen, bevor Sie sie finden können, und Sie stecken fest.

Mit EM können Sie zwischen diesen beiden Schritten wechseln, anstatt den gesamten Prozess auf einmal in Angriff zu nehmen.

Sie müssen mit einer Vermutung über die beiden Mittel beginnen (obwohl Ihre Vermutung nicht unbedingt sehr genau sein muss, müssen Sie irgendwo beginnen).

Wenn Ihre Vermutung über die Mittel zutreffend war, dann hätten Sie genug Informationen, um den Schritt in meinem ersten Aufzählungspunkt oben auszuführen, und Sie könnten (wahrscheinlich) jeden Datenpunkt einem der beiden Gaußschen zuordnen. Auch wenn wir wissen, dass unsere Vermutung falsch ist, versuchen wir es trotzdem. Anhand der zugewiesenen Verteilungen der einzelnen Punkte können Sie dann neue Schätzungen für die Mittelwerte erhalten, die den zweiten Aufzählungspunkt verwenden. Es stellt sich heraus, dass Sie jedes Mal, wenn Sie diese beiden Schritte durchlaufen, eine niedrigere Grenze für die Wahrscheinlichkeit des Modells verbessern.

Das ist schon ziemlich cool: Auch wenn die beiden Vorschläge in den obigen Aufzählungspunkten nicht so aussahen, als würden sie einzeln funktionieren, können Sie sie dennoch zusammen verwenden, um das Modell zu verbessern. Die wahre Magie von EM besteht darin, dass die Untergrenze nach einer ausreichenden Anzahl von Iterationen so hoch ist, dass zwischen ihr und dem lokalen Maximum kein Abstand mehr besteht. Infolgedessen haben Sie die Wahrscheinlichkeit lokal optimiert.

So haben Sie nicht nur verbessern das Modell, haben Sie das gefundenen beste mögliche Modell eines mit inkrementellem Updates finden.


Diese Seite aus Wikipedia zeigt ein etwas komplizierteres Beispiel (zweidimensionale Gaußsche und unbekannte Kovarianz), aber die Grundidee ist dieselbe. Es enthält auch gut kommentierten RCode zur Implementierung des Beispiels.

Im Code entspricht der Schritt "Erwartung" (E-Schritt) meinem ersten Aufzählungspunkt: Herausfinden, welcher Gauß'sche Wert für jeden Datenpunkt verantwortlich ist, wenn die aktuellen Parameter für jeden Gauß'schen Wert gegeben sind. Der "Maximierungs" -Schritt (M-Schritt) aktualisiert die Mittelwerte und Kovarianzen unter Berücksichtigung dieser Zuordnungen wie in meinem zweiten Aufzählungspunkt.

Wie Sie in der Animation sehen können, ermöglichen diese Aktualisierungen dem Algorithmus, schnell von einer Reihe schrecklicher Schätzungen zu einer Reihe sehr guter zu gelangen: Es scheint tatsächlich zwei Punktwolken zu geben, die auf den beiden von EM gefundenen Gaußschen Verteilungen zentriert sind.


13

Hier ist ein Beispiel für die Expectation Maximization (EM), mit der der Mittelwert und die Standardabweichung geschätzt werden. Der Code ist in Python, aber es sollte leicht zu befolgen sein, auch wenn Sie nicht mit der Sprache vertraut sind.

Die Motivation für EM

Die unten gezeigten roten und blauen Punkte stammen aus zwei verschiedenen Normalverteilungen mit jeweils einem bestimmten Mittelwert und einer bestimmten Standardabweichung:

Bildbeschreibung hier eingeben

Um vernünftige Annäherungen der "wahren" Mittel- und Standardabweichungsparameter für die Rotverteilung zu berechnen, könnten wir sehr einfach die roten Punkte betrachten und die Position von jedem aufzeichnen und dann die bekannten Formeln verwenden (und ähnlich für die blaue Gruppe) .

Betrachten Sie nun den Fall, in dem wir wissen, dass es zwei Gruppen von Punkten gibt, wir jedoch nicht sehen können, welcher Punkt zu welcher Gruppe gehört. Mit anderen Worten, die Farben sind versteckt:

Bildbeschreibung hier eingeben

Es ist überhaupt nicht klar, wie man die Punkte in zwei Gruppen aufteilt. Wir können jetzt nicht nur die Positionen betrachten und Schätzungen für die Parameter der Rotverteilung oder der Blauverteilung berechnen.

Hier kann EM zur Lösung des Problems eingesetzt werden.

Verwenden von EM zum Schätzen von Parametern

Hier ist der Code, der zum Generieren der oben gezeigten Punkte verwendet wird. Sie können die tatsächlichen Mittelwerte und Standardabweichungen der Normalverteilungen sehen, aus denen die Punkte gezogen wurden. Die Variablen redund bluehalten die Positionen der einzelnen Punkte in der roten bzw. der blauen Gruppe:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible random results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue)))

Wenn wir die Farbe jedes Punktes sehen könnten , würden wir versuchen, Mittelwerte und Standardabweichungen mithilfe von Bibliotheksfunktionen wiederherzustellen:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Aber da die Farben uns verborgen sind, werden wir den EM-Prozess starten ...

Zuerst raten wir nur die Werte für die Parameter jeder Gruppe ( Schritt 1 ). Diese Vermutungen müssen nicht gut sein:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Bildbeschreibung hier eingeben

Ziemlich schlechte Vermutungen - die Mittelwerte sehen so aus, als wären sie weit von einer "Mitte" einer Gruppe von Punkten entfernt.

Um mit EM fortzufahren und diese Vermutungen zu verbessern, berechnen wir die Wahrscheinlichkeit, dass jeder Datenpunkt (unabhängig von seiner geheimen Farbe) unter diesen Vermutungen für den Mittelwert und die Standardabweichung erscheint ( Schritt 2 ).

Die Variable both_coloursenthält jeden Datenpunkt. Die Funktion stats.normberechnet die Wahrscheinlichkeit des Punktes unter Normalverteilung mit den angegebenen Parametern:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Dies zeigt uns zum Beispiel, dass der Datenpunkt bei 1,761 nach unseren derzeitigen Schätzungen viel wahrscheinlicher rot (0,189) als blau (0,00003) ist.

Wir können diese beiden Wahrscheinlichkeitswerte in Gewichtungen umwandeln ( Schritt 3 ), sodass sie wie folgt zu 1 summieren:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Mit unseren aktuellen Schätzungen und unseren neu berechneten Gewichten können wir jetzt neue, wahrscheinlich bessere Schätzungen für die Parameter berechnen ( Schritt 4 ). Wir brauchen eine Funktion für den Mittelwert und eine Funktion für die Standardabweichung:

def estimate_mean(data, weight):
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

Diese sehen hinsichtlich Mittelwert und Standardabweichung der Daten den üblichen Funktionen sehr ähnlich. Der Unterschied besteht in der Verwendung eines weightParameters, der jedem Datenpunkt eine Gewichtung zuweist.

Diese Gewichtung ist der Schlüssel zu EM. Je größer das Gewicht einer Farbe auf einem Datenpunkt ist, desto stärker beeinflusst der Datenpunkt die nächsten Schätzungen für die Parameter dieser Farbe. Letztendlich hat dies den Effekt, dass jeder Parameter in die richtige Richtung gezogen wird.

Die neuen Vermutungen werden mit diesen Funktionen berechnet:

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

Der EM-Prozess wird dann mit diesen neuen Annahmen ab Schritt 2 wiederholt. Wir können die Schritte für eine bestimmte Anzahl von Iterationen (z. B. 20) wiederholen oder bis die Parameter konvergieren.

Nach fünf Iterationen sehen wir, dass sich unsere anfänglichen Fehleinschätzungen verbessern:

Bildbeschreibung hier eingeben

Nach 20 Iterationen ist der EM-Prozess mehr oder weniger konvergiert:

Bildbeschreibung hier eingeben

Zum Vergleich werden hier die Ergebnisse des EM-Prozesses mit den berechneten Werten verglichen, bei denen die Farbinformationen nicht verborgen sind:

          | EM guess | Actual 
----------+----------+--------
Red mean  |    2.910 |   2.802
Red std   |    0.854 |   0.871
Blue mean |    6.838 |   6.932
Blue std  |    2.227 |   2.195

Hinweis: Diese Antwort wurde von meiner Antwort auf Stack Overflow hier angepasst .


10

In Anlehnung an Zhubarbs Antwort habe ich das EM-Beispiel Do und Batzoglou zum "Münzwurf" in GNU R implementiert. Beachten Sie, dass ich die mleFunktion des stats4Pakets verwende - dies hat mir geholfen, den Zusammenhang zwischen EM und MLE besser zu verstehen.

require("stats4");

## sample data from Do and Batzoglou
ds<-data.frame(heads=c(5,9,8,4,7),n=c(10,10,10,10,10),
    coin=c("B","A","A","B","A"),weight_A=1:5*0)

## "baby likelihood" for a single observation
llf <- function(heads, n, theta) {
  comb <- function(n, x) { #nCr function
    return(factorial(n) / (factorial(x) * factorial(n-x)))
  }
  if (theta<0 || theta >1) { # probabilities should be in [0,1]
    return(-Inf);
  }
  z<-comb(n,heads)* theta^heads * (1-theta)^(n-heads);
  return (log(z))
}

## the "E-M" likelihood function
em <- function(theta_A,theta_B) {
  # expectation step: given current parameters, what is the likelihood
  # an observation is the result of tossing coin A (vs coin B)?
  ds$weight_A <<- by(ds, 1:nrow(ds), function(row) {
    llf_A <- llf(row$heads,row$n, theta_A);
    llf_B <- llf(row$heads,row$n, theta_B);

    return(exp(llf_A)/(exp(llf_A)+exp(llf_B)));
  })

  # maximisation step: given params and weights, calculate likelihood of the sample
  return(- sum(by(ds, 1:nrow(ds), function(row) {
    llf_A <- llf(row$heads,row$n, theta_A);
    llf_B <- llf(row$heads,row$n, theta_B);

    return(row$weight_A*llf_A + (1-row$weight_A)*llf_B);
  })))
}

est<-mle(em,start = list(theta_A=0.6,theta_B=0.5), nobs=NROW(ds))

1
@ user3096626 Können Sie bitte erklären, warum Sie im Maximierungsschritt die Wahrscheinlichkeit einer A-Münze (Zeile $ weight_A) mit einer logarithmischen Wahrscheinlichkeit (llf_A) multiplizieren? Gibt es eine spezielle Regel oder einen Grund, warum wir das tun? Ich meine, man würde nur die Wahrscheinlichkeiten oder loglikehoods multiplizieren, aber nicht hem zusammen mischen. Ich habe auch ein neues Thema eröffnet
Alina


5

Die Antwort von Zhubarb ist großartig, aber leider in Python. Im Folgenden finden Sie eine Java-Implementierung des EM-Algorithmus, der auf demselben Problem ausgeführt wird (siehe Artikel von Do und Batzoglou, 2008). Ich habe einige printf's zur Standardausgabe hinzugefügt, um zu sehen, wie die Parameter konvergieren.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

Java-Code folgt unten:

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
observed observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}

5
% Implementation of the EM (Expectation-Maximization)algorithm example exposed on:
% Motion Segmentation using EM - a short tutorial, Yair Weiss, %http://www.cs.huji.ac.il/~yweiss/emTutorial.pdf
% Juan Andrade, jandrader@yahoo.com

clear all
clc

%% Setup parameters
m1 = 2;                 % slope line 1
m2 = 6;                 % slope line 2
b1 = 3;                 % vertical crossing line 1
b2 = -2;                % vertical crossing line 2
x = [-1:0.1:5];         % x axis values
sigma1 = 1;             % Standard Deviation of Noise added to line 1
sigma2 = 2;             % Standard Deviation of Noise added to line 2

%% Clean lines
l1 = m1*x+b1;           % line 1
l2 = m2*x+b2;           % line 2

%% Adding noise to lines
p1 = l1 + sigma1*randn(size(l1));
p2 = l2 + sigma2*randn(size(l2));

%% showing ideal and noise values
figure,plot(x,l1,'r'),hold,plot(x,l2,'b'), plot(x,p1,'r.'),plot(x,p2,'b.'),grid

%% initial guess
m11(1) = -1;            % slope line 1
m22(1) = 1;             % slope line 2
b11(1) = 2;             % vertical crossing line 1
b22(1) = 2;             % vertical crossing line 2

%% EM algorithm loop
iterations = 10;        % number of iterations (a stop based on a threshold may used too)

for i=1:iterations

    %% expectation step (equations 2 and 3)
    res1 = m11(i)*x + b11(i) - p1;
    res2 = m22(i)*x + b22(i) - p2;
    % line 1
    w1 = (exp((-res1.^2)./sigma1))./((exp((-res1.^2)./sigma1)) + (exp((-res2.^2)./sigma2)));

    % line 2
    w2 = (exp((-res2.^2)./sigma2))./((exp((-res1.^2)./sigma1)) + (exp((-res2.^2)./sigma2)));

    %% maximization step  (equation 4)
    % line 1
    A(1,1) = sum(w1.*(x.^2));
    A(1,2) = sum(w1.*x);
    A(2,1) = sum(w1.*x);
    A(2,2) = sum(w1);
    bb = [sum(w1.*x.*p1) ; sum(w1.*p1)];
    temp = A\bb;
    m11(i+1) = temp(1);
    b11(i+1) = temp(2);

    % line 2
    A(1,1) = sum(w2.*(x.^2));
    A(1,2) = sum(w2.*x);
    A(2,1) = sum(w2.*x);
    A(2,2) = sum(w2);
    bb = [sum(w2.*x.*p2) ; sum(w2.*p2)];
    temp = A\bb;
    m22(i+1) = temp(1);
    b22(i+1) = temp(2);

    %% plotting evolution of results
    l1temp = m11(i+1)*x+b11(i+1);
    l2temp = m22(i+1)*x+b22(i+1);
    figure,plot(x,l1temp,'r'),hold,plot(x,l2temp,'b'), plot(x,p1,'r.'),plot(x,p2,'b.'),grid
end

4
Können Sie dem Rohcode eine Diskussion oder Erklärung hinzufügen? Vielen Lesern wäre es nützlich, zumindest die Sprache zu erwähnen, in der Sie schreiben.
Glen_b

1
@ Glen_b - das ist MatLab. Ich frage mich, wie höflich es ist, jemanden in seiner Antwort ausführlicher mit Anmerkungen zu versehen.
EngrStudent

4

Nun, ich würde Ihnen vorschlagen, ein Buch über R von Maria L Rizzo zu lesen. Eines der Kapitel enthält die Verwendung des EM-Algorithmus mit einem numerischen Beispiel. Ich erinnere mich, dass ich den Code zum besseren Verständnis durchgesehen habe.

Versuchen Sie auch, es zu Beginn aus der Sicht eines Clusters zu betrachten. Arbeiten Sie mit der Hand ein Clustering-Problem aus, bei dem 10 Beobachtungen aus zwei verschiedenen Normaldichten entnommen werden. Dies sollte helfen. Nehmen Sie Hilfe von R :)


2

Nur für den Fall, ich habe eine Ruby- Implementierung des oben erwähnten Münzwurf-Beispiels von Do & Batzoglou geschrieben und es erzeugt genau die gleichen Zahlen wie sie mit den gleichen Eingabeparametern und . θ B = 0,5θA=0.6θB=0.5

# gem install distribution
require 'distribution'

# error bound
EPS = 10**-6

# number of coin tosses
N = 10

# observations
X = [5, 9, 8, 4, 7]

# randomly initialized thetas
theta_a, theta_b = 0.6, 0.5

p [theta_a, theta_b]

loop do
  expectation = X.map do |h|
    like_a = Distribution::Binomial.pdf(h, N, theta_a)
    like_b = Distribution::Binomial.pdf(h, N, theta_b)

    norm_a = like_a / (like_a + like_b)
    norm_b = like_b / (like_a + like_b)

    [norm_a, norm_b, h]
  end

  maximization = expectation.each_with_object([0.0, 0.0, 0.0, 0.0]) do |(norm_a, norm_b, h), r|
    r[0] += norm_a * h; r[1] += norm_a * (N - h)
    r[2] += norm_b * h; r[3] += norm_b * (N - h)
  end

  theta_a_hat = maximization[0] / (maximization[0] + maximization[1])
  theta_b_hat = maximization[2] / (maximization[2] + maximization[3])

  error_a = (theta_a_hat - theta_a).abs / theta_a
  error_b = (theta_b_hat - theta_b).abs / theta_b

  theta_a, theta_b = theta_a_hat, theta_b_hat

  p [theta_a, theta_b]

  break if error_a < EPS && error_b < EPS
end
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.