Welche Parameter sollten zum vorzeitigen Stoppen verwendet werden?


97

Ich trainiere ein neuronales Netzwerk für mein Projekt mit Keras. Keras hat eine Funktion zum frühen Stoppen bereitgestellt. Darf ich wissen, welche Parameter beachtet werden sollten, um eine Überanpassung meines neuronalen Netzwerks durch frühzeitiges Anhalten zu vermeiden?

Antworten:


156

frühes Anhalten

Frühes Anhalten bedeutet im Grunde, das Training abzubrechen, sobald Ihr Verlust zuzunehmen beginnt (oder mit anderen Worten, die Validierungsgenauigkeit abnimmt). Nach Unterlagen wird es wie folgt verwendet;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

Die Werte hängen von Ihrer Implementierung ab (Problem, Stapelgröße usw.), aber im Allgemeinen würde ich eine Überanpassung verwenden, um eine Überanpassung zu verhindern.

  1. Überwachen Sie den Validierungsverlust (Sie müssen eine Kreuzvalidierung oder mindestens Zug- / Testsätze verwenden), indem Sie das monitor Argument auf setzen 'val_loss'.
  2. min_deltaist eine Schwelle dafür, ob ein Verlust in einer bestimmten Epoche als Verbesserung quantifiziert wird oder nicht. Wenn die Verlustdifferenz geringer ist min_delta, wird sie als keine Verbesserung quantifiziert. Es ist besser, es als 0 zu belassen, da wir daran interessiert sind, wann der Verlust schlimmer wird.
  3. patienceDas Argument gibt die Anzahl der Epochen vor dem Stoppen an, sobald Ihr Verlust zu steigen beginnt (hört auf, sich zu verbessern). Dies hängt von Ihrer Implementierung ab. Wenn Sie sehr kleine Stapel oder eine große Lernrate verwenden, ist Ihr Verlust im Zick-Zack (die Genauigkeit wird lauter), also setzen Sie besser ein großes patienceArgument. Wenn Sie große Chargen und eine kleine Lernrate verwenden Ihr Verlust gleichmäßiger, sodass Sie ein kleineres patienceArgument verwenden können. In beiden Fällen lasse ich es als 2, damit ich dem Modell mehr Chancen geben kann.
  4. verbose entscheidet, was gedruckt werden soll, belassen Sie die Standardeinstellung (0).
  5. modeDas Argument hängt davon ab, in welche Richtung Ihre überwachte Menge geht (soll sie abnehmen oder zunehmen), da wir den Verlust überwachen, den wir verwenden können min. Aber lassen wir Keras das für uns erledigen und setzen das aufauto

Also würde ich so etwas verwenden und experimentieren, indem ich den Fehlerverlust mit und ohne vorzeitiges Anhalten aufzeichne.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=2,
                              verbose=0, mode='auto')

Im Hinblick auf mögliche Unklarheiten bei der Funktionsweise von Rückrufen werde ich versuchen, mehr zu erklären. Sobald Sie fit(... callbacks=[es])Ihr Modell aufrufen , ruft Keras vorgegebene Funktionen für bestimmte Rückrufobjekte auf. Diese Funktionen können aufgerufen werden on_train_begin, on_train_end, on_epoch_begin, on_epoch_endund on_batch_begin,on_batch_end . Ein früher Stopp-Rückruf wird an jedem Epochenende aufgerufen, vergleicht den am besten überwachten Wert mit dem aktuellen und stoppt, wenn die Bedingungen erfüllt sind (wie viele Epochen sind seit der Beobachtung des am besten überwachten Werts vergangen und ist es mehr als ein Geduldargument, der Unterschied zwischen letzter Wert ist größer als min_delta etc ..).

Wie von @BrentFaust in den Kommentaren angegeben, wird das Training des Modells fortgesetzt, bis entweder die Bedingungen für das frühzeitige Anhalten erfüllt sind oder der epochsParameter (Standard = 10) in fit()erfüllt ist. Durch das Festlegen eines Rückrufs zum vorzeitigen Stoppen wird das Modell nicht über seinen epochsParameter hinaus trainiert . Das Aufrufen einer fit()Funktion mit einem größeren epochsWert würde also mehr vom Frühstopp-Rückruf profitieren.


3
@AizuddinAzman close min_deltaist ein Schwellenwert dafür, ob die Änderung des überwachten Werts als Verbesserung quantifiziert wird oder nicht. Also ja, wenn wir geben, monitor = 'val_loss'dann würde es sich auf den Unterschied zwischen dem aktuellen Validierungsverlust und dem vorherigen Validierungsverlust beziehen. In der Praxis würde min_delta=0.1eine Verringerung des Validierungsverlusts (aktuell - vorhergehend) von weniger als 0,1 nicht quantifizieren und somit das Training beenden (falls vorhanden patience = 0).
Umutto

3
Beachten Sie, dass dies callbacks=[EarlyStopping(patience=2)]keine Auswirkung hat, es sei denn, Epochen sind gegeben model.fit(..., epochs=max_epochs).
Brent Faust

1
@BrentFaust Auch das verstehe ich. Ich habe die Antwort unter der Annahme geschrieben, dass das Modell mit mindestens 10 Epochen trainiert wird (standardmäßig). Nach Ihrem Kommentar habe ich festgestellt, dass es möglicherweise einen Fall gibt, mit dem der Programmierer fit epoch=1in einer for-Schleife (für verschiedene Anwendungsfälle) aufruft, in dem dieser Rückruf fehlschlagen würde. Wenn meine Antwort mehrdeutig ist, werde ich versuchen, sie besser auszudrücken.
Umutto

4
@AdmiralWen Seit ich die Antwort geschrieben habe, hat sich der Code etwas geändert. Wenn Sie die neueste Version von Keras verwenden, können Sie das restore_best_weightsArgument (noch nicht in der Dokumentation) verwenden, mit dem das Modell nach dem Training mit den besten Gewichten geladen wird. Aber für Ihre Zwecke würde ich ModelCheckpointRückruf mit save_best_onlyArgument verwenden. Sie können die Dokumentation überprüfen. Die Verwendung ist unkompliziert, Sie müssen jedoch nach dem Training die besten Gewichte manuell laden.
Umutto

1
@umutto Hallo, danke für den Vorschlag von restore_best_weights, aber ich kann ihn nicht verwenden. `es = EarlyStopping (monitor = 'val_acc', min_delta = 1e-4, geduld = geduld_, verbose = 1, restore_best_weights = True) TypeError: __init __ () hat ein unerwartetes Schlüsselwortargument 'restore_best_weights'` erhalten. Irgendwelche Ideen? keras 2.2.2, tf, 1.10 was ist deine version?
Haramoz
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.