WARNUNG: Tensorflow: Die Modi sample_weight wurden von… nach ['…'] gezwungen.


47

Trainieren eines Bildklassifikators unter Verwendung .fit_generator()oder .fit()Übergeben eines Wörterbuchs class_weight=als Argument.

Ich habe in TF1.x nie Fehler bekommen, aber in 2.1 bekomme ich zu Beginn des Trainings die folgende Ausgabe:

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']

Was bedeutet es, etwas von ...bis zu erzwingen ['...']?

Die Quelle für diese Warnung auf tensorflowdem Repo ist hier . Kommentare sind:

Versuchen Sie, sample_weight_modes zur Zielstruktur zu zwingen. Dies hängt implizit davon ab, dass das Modell die Ausgaben für seine interne Darstellung abflacht.


7
Es ist lustig, eine so aktuelle Frage als einziges Suchergebnis für meine eigenen Warnungen zu sehen.
jmkjaer

1
@jorijnsmit Können Sie den Code angeben, um das Problem / die Warnung zu replizieren?
Thushv89

2
Der tatsächliche Wechsel zu TF2 mit %tensorflow_version 2.xreicht aus, um diese Warnung anzuzeigen
jorijnsmit

1
@jorijnsmit, Nein, ich bekomme die gleiche Warnung, habe aber tatsächlich TF2.1 installiert wie pip install tensorflow(innerhalb der pyenv / virtualenv-Umgebung)
lurix66

1
Ja, in der Tat, @ lurix66, der Code, der diesen Fehler generiert, wird in eingeführt 2.1.0rc0.
17.

Antworten:


11

Dies scheint eine falsche Nachricht zu sein. Nach dem Upgrade auf TensorFlow 2.1 wird dieselbe Warnmeldung angezeigt, ich verwende jedoch überhaupt keine Klassen- oder Stichprobengewichte. Ich benutze einen Generator, der ein Tupel wie dieses zurückgibt:

return inputs, targets

Und jetzt habe ich es einfach wie folgt geändert, damit die Warnung verschwindet:

return inputs, targets, [None]

Ich weiß nicht, ob dies relevant ist, aber mein Modell verwendet 3 Eingaben, sodass meine inputsVariable tatsächlich eine Liste von 3 Numpy-Arrays ist. targetsist nur ein einzelnes Numpy-Array.

In jedem Fall ist es nur eine Warnung. Das Training funktioniert so oder so gut.

Für TensorFlow 2.2 bearbeiten:

Dieser Fehler scheint in TensorFlow 2.2 behoben worden zu sein, was großartig ist. Die obige Korrektur schlägt jedoch in TF 2.2 fehl, da versucht wird, die Form der Probengewichte zu ermitteln, was offensichtlich fehlschlägt AttributeError: 'NoneType' object has no attribute 'shape'. Machen Sie das obige Update rückgängig, wenn Sie auf 2.2 aktualisieren.


Das funktioniert auch bei mir.
Robert Lugg

14

Ich glaube, dies ist ein Fehler mit Tensorflow, der auftritt, wenn Sie model.compile()mit Standardparametern sample_weight_mode=Noneaufrufen und dann model.fit()mit angegebenem sample_weightoder aufrufen class_weight.

Aus den Tensorflow-Repos:

  • fit() ruft schließlich an _process_training_inputs()
  • _process_training_inputs() setzt sample_weight_modes = [None] basierend auf model.sample_weight_mode = Noneund erstellt dann ein DataAdaptermitsample_weight_modes = [None]
  • die DataAdapterAnrufe broadcast_sample_weight_modes()mit sample_weight_modes = [None]während der Initialisierung
  • broadcast_sample_weight_modes() scheint zu erwarten sample_weight_modes = None , erhält aber[None]
  • Es wird behauptet, dass [None]es sich um eine andere Struktur als sample_weight/ handelt class_weight, es wird Nonedurch Anpassen an die Struktur von sample_weight/ zurückgeschrieben class_weightund eine Warnung ausgegeben

Warnung beiseite dies hat keine Auswirkung auf fit()wie sample_weight_modesin der zurückgesetzt DataAdapterwird None.

Beachten Sie, dass in der Tensorflow- Dokumentation angegeben ist , dass sample_weightes sich um ein Numpy-Array handeln muss. Wenn Sie stattdessen fit()mit anrufen sample_weight.tolist(), erhalten Sie keine Warnung, sondern sample_weightwerden stillschweigend überschrieben, Nonewenn _process_numpy_inputs()sie in der Vorverarbeitung aufgerufen werden und eine Eingabe mit einer Länge von mehr als eins erhalten.


1
Eine sehr gründliche Erklärung, danke. Das einzige, was ich nicht verstehe, ist, dass die Warnung beschreibt ..., gezwungen zu werden [...], während in Ihrem Fall [None]gezwungen wird, None...
jorijnsmit

4

Ich habe Ihr Gist genommen und Tensorflow 2.0 anstelle von TFA installiert, und es hat ohne eine solche Warnung funktioniert.

Hier ist der Kern des vollständigen Codes. Der Code für die Installation des Tensorflow ist unten dargestellt:

!pip install tensorflow==2.0

Der Screenshot der erfolgreichen Ausführung ist unten dargestellt:

Geben Sie hier die Bildbeschreibung ein

Update: Dieser Fehler wurde behobenTensorflow Version 2.2.


5
Danke für Ihre Antwort. Sie haben Recht, die Warnmeldung wird erst in der Version eingeführt 2.1.0rc0. Ich fürchte jedoch, meine Frage bleibt: "Was bedeutet es, etwas von ...bis zu erzwingen ['...']?"
Jorijnsmit

3
Mir ist aufgefallen, dass einige wahrscheinlich unbeabsichtigte Dinge passieren, wenn sample_weight_mode=Noneund target_structureist vom Typ dict, sample_weight_modesdann [None]und die Ausnahme in broadcast_sample_weight_modeswird aufgrund der gefangen dict. Kann dies als Fehler angesehen werden?
Franz Knülle

2
Nee. Die Frage sammelt immer wieder Ansichten und Stimmen, aber keine Antworten.
jorijnsmit

1
@gkennos: Wenn Sie der Meinung sind, dass es sich um einen Fehler handelt, können Sie einen Fehler im Github Tensorflow Repository einreichen.
Tensorflow-Unterstützung

1
Es ist definitiv ein Fehler, aber es ist jetzt in TensorFlow 2.2
23.

2

anstatt ein Wörterbuch bereitzustellen

weights = {'0': 42.0, '1': 1.0}

Ich habe eine Liste ausprobiert

weights = [42.0, 1.0]

und die Warnung verschwand.


Danke, Mann! Ich habe versucht (erfolglos) mit Wörterbüchern. Durch die Verwendung der Liste wird der Fehler behoben!
Victor Mondejar-Guerra

Während dies den Fehler beseitigt, führt dies für mich dazu, dass die Gewichtung für jede Klasse schlechtere Ergebnisse liefert. Ich würde die Konsistenz überprüfen, bevor ich zu einer Liste wechsle.
CanofDrink
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.