Wie erreicht ein einfaches logistisches Regressionsmodell eine Klassifizierungsgenauigkeit von 92% für MNIST?


64

Obwohl alle Bilder im MNIST-Datensatz in einem ähnlichen Maßstab zentriert und ohne Rotation sichtbar sind, weisen sie eine signifikante Variation der Handschrift auf, die mich verblüfft, wie ein lineares Modell eine so hohe Klassifizierungsgenauigkeit erzielt.

Soweit ich in der Lage bin, angesichts der signifikanten Variation der Handschrift zu visualisieren, sollten die Ziffern in einem 784-dimensionalen Raum linear untrennbar sein, dh es sollte eine kleine komplexe (wenn auch nicht sehr komplexe) nichtlineare Grenze geben, die die verschiedenen Ziffern voneinander trennt ähnlich das gut zitierte XOR Beispiel , in dem positiven und negativen Klassen können nicht durch eine lineare Klassifizierer getrennt werden. Es scheint mir verwirrend, wie die logistische Regression mehrerer Klassen eine so hohe Genauigkeit mit vollständig linearen Merkmalen (keine Polynommerkmale) erzeugt.

Beispielsweise können bei einem beliebigen Pixel im Bild verschiedene handschriftliche Variationen der Ziffern 2 und 3 bewirken, dass dieses Pixel beleuchtet wird oder nicht. Daher kann mit einem Satz von erlernten Gewichten jedes Pixel eine Ziffer sowohl als 2 als auch als 3 aussehen lassen . Nur mit einer Kombination von Pixelwerten sollte es möglich sein zu sagen, ob eine Ziffer eine 2 oder eine 3 . Dies gilt für die meisten Ziffernpaare. Wie kann also die logistische Regression, die ihre Entscheidung blind auf alle Pixelwerte stützt (ohne Berücksichtigung von Abhängigkeiten zwischen Pixeln), so hohe Genauigkeiten erzielen?

Ich weiß, dass ich irgendwo falsch liege oder die Abweichungen in den Bildern einfach zu hoch einschätze. Es wäre jedoch großartig, wenn mir jemand mit einer Intuition helfen könnte, wie die Ziffern "fast" linear trennbar sind.


Schauen Sie sich das Lehrbuch Statistisches Lernen mit Sparsity an: das Lasso und Verallgemeinerungen 3.3.1 Beispiel: Handschriftliche Ziffern web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

Ich war neugierig: Wie gut kann so etwas wie ein bestraftes lineares Modell (dh glmnet) das Problem lösen? Wenn ich mich recht entsinne, berichten Sie von der Genauigkeit, mit der die Stichprobe nicht bestraft wurde.
Cliff AB

Antworten:


82

tl; dr Auch wenn dies ist eine Bild Klassifizierung Dataset, bleibt es eine sehr einfache Aufgabe, für die man leicht einen finden kann direkte Zuordnung von Eingängen zu Prognosen.


Antworten:

Dies ist eine sehr interessante Frage, und dank der Einfachheit der logistischen Regression können Sie die Antwort tatsächlich herausfinden.

78478428×28

Beachten Sie erneut, dass dies die Gewichte sind .

Schauen Sie sich nun das obige Bild an und konzentrieren Sie sich auf die ersten beiden Ziffern (dh Null und Eins). Blaue Gewichte bedeuten, dass die Intensität dieses Pixels einen großen Beitrag für diese Klasse leistet, und rote Werte bedeuten, dass sie einen negativen Beitrag leistet.

0

1

2378

Hierdurch können Sie sehen, dass die logistische Regression eine sehr gute Chance hat, viele Bilder richtig zu machen, und das ist der Grund, warum sie so gut abschneidet.


Der Code zum Reproduzieren der obigen Abbildung ist etwas veraltet, aber hier ist es:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

9
2378

13
Natürlich hilft es, dass MNIST-Samples zentriert, skaliert und kontrastnormalisiert werden, bevor der Klassifikator sie jemals sieht. Sie müssen sich nicht mit Fragen wie "Was ist, wenn die Kante der Null tatsächlich durch die Mitte der Box geht?" weil der Pre-Prozessor schon viel getan hat, um alle Nullen gleich aussehen zu lassen.
Hobbs

1
@EricDuminil Ich habe dem Skript einen Kommentar mit Ihrem Vorschlag hinzugefügt. Vielen Dank für die Eingabe! : D
Djib2011

1
@ NitishAgarwal, Wenn Sie der Meinung sind, dass diese Antwort die Antwort auf Ihre Frage ist, sollten Sie sie als solche markieren.
Sintax

7
Für jemanden, der an dieser Art der Verarbeitung interessiert, aber nicht besonders vertraut ist, bietet diese Antwort ein fantastisches intuitives Beispiel für die Mechanik.
chrylis -on strike-
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.