Ich verwende Tensorflow, schreibe jedoch eine Dokumentation für Benutzer, die in der Regel je nach Deep-Learning-Framework variiert .
Wenn ich mit Datensätzen arbeite, die nicht in das lokale Dateisystem (TB +) passen, probiere ich Daten aus einem entfernten Datenspeicher und schreibe Proben lokal in ein Tensorflow-Standardformat tfrecords
.
Während der ersten Trainingsepoche habe ich nur einige Werte abgetastet, daher ist eine Epoche lokaler Daten sehr klein, ich trainiere darauf. In Epoche 2 überprüfe ich erneut, welche Datendateien von meinen Stichproben-Teilprozessen erzeugt wurden (jetzt mehr), und trainiere den erweiterten Satz lokaler Datendateien für die nächste Epoche. Wiederholen Sie den Vorgang in jeder Epoche. Auf diese Weise baue ich einen lokalen Cache mit Samples auf und kann ältere Samples entfernen, wenn ich den lokalen Speicher auffülle. Der lokale Stichproben-Cache wächst ungefähr zu dem Zeitpunkt, an dem das Modell die Varianz am meisten benötigt (in Richtung des letzten Teils des Trainings).
In Python / Tensorflow ist es entscheidend, dass ich die Daten im Python-Trainingsschleifenprozess nicht deserialisiere, da die Python-GIL die Datenübertragungsraten (300-600 MB / s, die Daten sind wissenschaftlich unkomprimierbar) und damit die GPU-Leistung nicht unterstützen kann leidet, wenn die Python GIL die Trainingsschleife nicht schnell bedienen kann.
Durch das Schreiben der Samples in eine tfrecords
Datei aus Unterprozessen (Python-Multiprocessing) kann der native Tensorflow eine TFRecordsDataset
Deserialisierung außerhalb von Python durchführen. Daher umgehen wir die Python-GIL-Probleme und können eine GPU mit hohen E / A-Datenraten sättigen.
Ich würde gerne wissen, wie ich dieses Problem in Pytorch angehen würde. Ich schreibe über die verwendete Stichprobenstrategie und möchte Benutzern von Tensorflow und PyTorch spezifische Empfehlungen geben, aber ich kenne das PyTorch-Vorverarbeitungs-Ökosystem nicht gut genug, um mit ausreichenden Details zu schreiben.
Randnotiz: Die einzige rein Python-basierte Lösung zur Unterstützung dieser Datenübertragungsraten ist möglicherweise Python 3.8 mit gemeinsam genutztem System V-Speicher und Multiprocessing. Ich habe dies jedoch noch nicht versucht, da die Unterstützung dafür nicht ausreicht (bald wird es so sein) ). Bestehende Multiprozessor-Lösungen reichen nicht aus, da sie eine Deserialisierung im Trainingsschleifenprozess erfordern und somit die GIL während der Deserialisierung mit hohen E / A-Raten sperren.
DataLoader
wie in meiner Antwort geladen werden .