Wie trainiere ich Daten stapelweise von der Festplatte?


8

Ich arbeite an einem Faltungsnetzwerk zur Bildklassifizierung. Der Trainingsdatensatz ist zu groß, um auf meinen Computerspeicher geladen zu werden (4 GB). Außerdem muss ich eine Erweiterung versuchen, um die Klassen auszugleichen.

Ich benutze keras. Ich habe viele Beispiele untersucht, aber keine Lösung für mein Problem gefunden. Ich denke darüber nach, die Parameter des model.fitAufrufs durch Übergeben einer Funktion oder eines Generators zu ersetzen, der das 'Batching' ausführt. Diese Funktion werde ich entwerfen, um sie von der Festplatte zu importieren und eine Erweiterung anzuwenden.

Ich habe keine Ahnung, wie ich das umsetzen soll. Irgendwelche Vorschläge?


Können Sie mir zeigen, wie diese Zugfunktionen und die entsprechenden Zugfasern auf der Festplatte gespeichert werden, damit ich sie beim nächsten Mal nicht mehr berechnen muss?
Supreeth Ys

Antworten:


8

Während Sie an der Bildklassifizierung arbeiten und auch eine Datenerweiterung implementieren möchten, können Sie die beiden kombinieren UND die Stapel mithilfe der mächtigen 'ImageDataGenerator'-Klasse direkt aus einem Ordner laden.

Schauen Sie sich die ausführliche Dokumentation an!

Ich werde das Beispiel von diesem Link nicht kopieren und einfügen, aber ich kann die Schritte skizzieren, die Sie durchlaufen:

  1. Erstellen Sie die Generatorklasse: data_gen = ImageDataGenerator()

  2. Wenn Sie möchten, dass die Erweiterung im laufenden Betrieb für Sie durchgeführt wird, kann dies beim Erstellen der Klasse angegeben werden: data_gen = ImageDataGenerator(samplewise_center=True, ...)

  3. Wenn Sie Erweiterungsprozesse verwenden, für die einige Statistiken zum Dataset erforderlich sind, z. B. eine merkmalweise Normalisierung (nicht stichprobenweise), müssen Sie den Generator vorbereiten, indem Sie einige Ihrer Daten anzeigen : data_gen.fit(training_data). Diese fitMethode berechnet einfach Dinge wie den Mittelwert und die Standardabweichung vor, die später zur Normalisierung verwendet werden.

  4. Der Generator geht in die fit_generatorMethode des Modells über , und wir nennen die flow_from_directoryMethode des Generators:

    model.fit_generator(training_data=data_gen.flow_from_directory('/path/to/folder/'), ...)

Sie können mit ImageDataGenerator auch einen separaten Generator für Ihre Validierungsdaten erstellen, in dem Sie die Erweiterung dann nicht anwenden sollten, damit Validierungstests an realen Daten durchgeführt werden, um ein genaues Bild der Modellleistung zu erhalten.

In jedem Fall laufen diese Generatoren theoretisch für immer und generieren Stapel aus Ihrem Ordner. Daher empfehle ich die Verwendung einer Rückruffunktion von Keras, um zu stoppen, wenn bestimmte Kriterien erfüllt sind. Weitere Informationen finden Sie in der Dokumentation zur EarlyStopping-Klasse . Sie können dies auch manuell tun, aber Keras macht es sehr einfach!

Wenn Sie eine feinkörnige Steuerung wünschen, können Sie alle oben genannten Schritte manuell ausführen, indem Sie genügend Proben von der Festplatte für einen einzelnen Stapel laden, eine Erweiterung durchführen und dann die model.train_on_batch()Methode ausführen . Wenn Sie sich mit den Details befassen möchten, lernen Sie am besten zuerst die Keras-Methode und fahren dann mit Ihren eigenen detaillierten Modellen fort, die Tensorflow nach Bedarf kämmen. Die beiden können sehr gut zusammen verwendet werden!


Ich habe den 3. Schritt nicht verstanden. Können Sie mir sagen, dass im 3. Schritt training_data den gesamten Datensatz oder einige Daten aus dem Datensatz enthalten soll?
Junaid

@Junaid - Es kann ein kleiner Teil der Daten sein. Es muss jedoch ausreichen, einen vernünftigen Wert zu berechnen, z. B. den Mittelwert, der dann während des Trainings für den gesamten Datensatz verwendet wird. Ich habe Punkt 3 ein wenig mehr Informationen hinzugefügt. Hier ist ein Link zu einem Beispiel der flow_from_directoryMethode
n1k31t4
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.