Bedeutung von buffer_size
inshuffle()
Ich wollte die vorherige Antwort von @mrry weiterverfolgen, um die Wichtigkeit von buffer_size
in hervorzuheben tf.data.Dataset.shuffle()
.
Ein niedriges Niveau führt in einigen Fällen buffer_size
nicht nur zu einem minderwertigen Mischen, sondern kann auch Ihr gesamtes Training durcheinander bringen.
Ein praktisches Beispiel: Katzenklassifikator
Angenommen, Sie trainieren einen Katzenklassifikator für Bilder und Ihre Daten sind folgendermaßen organisiert (mit 10000
Bildern in jeder Kategorie):
train/
cat/
filename_00001.jpg
filename_00002.jpg
...
not_cat/
filename_10001.jpg
filename_10002.jpg
...
Eine Standardmethode zur Dateneingabe mit tf.data
kann darin bestehen, eine Liste mit Dateinamen und eine Liste der entsprechenden Beschriftungen tf.data.Dataset.from_tensor_slices()
zu erstellen und den Datensatz zu erstellen:
filenames = ["filename_00001.jpg", "filename_00002.jpg", ...,
"filename_10001.jpg", "filename_10002.jpg", ...]
labels = [1, 1, ..., 0, 0...] # 1 for cat, 0 for not_cat
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=1000) # 1000 should be enough right?
dataset = dataset.map(...) # transform to images, preprocess, repeat, batch...
Das große Problem mit dem obigen Code ist, dass der Datensatz tatsächlich nicht richtig gemischt wird. In der ersten Hälfte einer Epoche werden nur Katzenbilder und in der zweiten Hälfte nur Nicht-Katzenbilder angezeigt. Dies wird das Training sehr verletzen.
Zu Beginn des Trainings nimmt der Datensatz die ersten 1000
Dateinamen und legt sie in seinen Puffer. Wählen Sie dann zufällig einen aus. Da alle ersten 1000
Bilder Bilder von Katzen sind, werden wir nur zu Beginn Katzenbilder auswählen.
Das Update hier ist, um sicherzustellen, dass buffer_size
größer als ist20000
, oder im Voraus zu mischen filenames
und labels
(offensichtlich mit denselben Indizes).
Da das Speichern aller Dateinamen und Beschriftungen im Speicher kein Problem darstellt, können wir tatsächlich buffer_size = len(filenames)
sicherstellen, dass alles zusammengemischt wird. Stellen Sie sicher, dass Sie anrufen, tf.data.Dataset.shuffle()
bevor Sie die umfangreichen Transformationen anwenden (z. B. Lesen der Bilder, Verarbeiten, Stapeln ...).
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=len(filenames))
dataset = dataset.map(...) # transform to images, preprocess, repeat, batch...
Zum Mitnehmen muss immer überprüft werden, was das Mischen bewirkt. Eine gute Möglichkeit, diese Fehler zu erkennen, besteht darin, die Verteilung der Chargen über die Zeit zu zeichnen (stellen Sie sicher, dass die Chargen ungefähr die gleiche Verteilung wie das Trainingsset enthalten, in unserem Beispiel halb Katze und halb Nichtkatze).