Bedeutung von buffer_sizeinshuffle()
Ich wollte die vorherige Antwort von @mrry weiterverfolgen, um die Wichtigkeit von buffer_sizein hervorzuheben tf.data.Dataset.shuffle().
Ein niedriges Niveau führt in einigen Fällen buffer_sizenicht nur zu einem minderwertigen Mischen, sondern kann auch Ihr gesamtes Training durcheinander bringen.
Ein praktisches Beispiel: Katzenklassifikator
Angenommen, Sie trainieren einen Katzenklassifikator für Bilder und Ihre Daten sind folgendermaßen organisiert (mit 10000Bildern in jeder Kategorie):
train/
cat/
filename_00001.jpg
filename_00002.jpg
...
not_cat/
filename_10001.jpg
filename_10002.jpg
...
Eine Standardmethode zur Dateneingabe mit tf.data kann darin bestehen, eine Liste mit Dateinamen und eine Liste der entsprechenden Beschriftungen tf.data.Dataset.from_tensor_slices()zu erstellen und den Datensatz zu erstellen:
filenames = ["filename_00001.jpg", "filename_00002.jpg", ...,
"filename_10001.jpg", "filename_10002.jpg", ...]
labels = [1, 1, ..., 0, 0...] # 1 for cat, 0 for not_cat
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=1000) # 1000 should be enough right?
dataset = dataset.map(...) # transform to images, preprocess, repeat, batch...
Das große Problem mit dem obigen Code ist, dass der Datensatz tatsächlich nicht richtig gemischt wird. In der ersten Hälfte einer Epoche werden nur Katzenbilder und in der zweiten Hälfte nur Nicht-Katzenbilder angezeigt. Dies wird das Training sehr verletzen.
Zu Beginn des Trainings nimmt der Datensatz die ersten 1000Dateinamen und legt sie in seinen Puffer. Wählen Sie dann zufällig einen aus. Da alle ersten 1000Bilder Bilder von Katzen sind, werden wir nur zu Beginn Katzenbilder auswählen.
Das Update hier ist, um sicherzustellen, dass buffer_sizegrößer als ist20000 , oder im Voraus zu mischen filenamesund labels(offensichtlich mit denselben Indizes).
Da das Speichern aller Dateinamen und Beschriftungen im Speicher kein Problem darstellt, können wir tatsächlich buffer_size = len(filenames)sicherstellen, dass alles zusammengemischt wird. Stellen Sie sicher, dass Sie anrufen, tf.data.Dataset.shuffle()bevor Sie die umfangreichen Transformationen anwenden (z. B. Lesen der Bilder, Verarbeiten, Stapeln ...).
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=len(filenames))
dataset = dataset.map(...) # transform to images, preprocess, repeat, batch...
Zum Mitnehmen muss immer überprüft werden, was das Mischen bewirkt. Eine gute Möglichkeit, diese Fehler zu erkennen, besteht darin, die Verteilung der Chargen über die Zeit zu zeichnen (stellen Sie sicher, dass die Chargen ungefähr die gleiche Verteilung wie das Trainingsset enthalten, in unserem Beispiel halb Katze und halb Nichtkatze).