TensorFlow Speichern / Laden eines Diagramms aus einer Datei


98

Nach dem, was ich bisher gesammelt habe, gibt es verschiedene Möglichkeiten, ein TensorFlow-Diagramm in eine Datei zu kopieren und dann in ein anderes Programm zu laden, aber ich konnte keine eindeutigen Beispiele / Informationen zu deren Funktionsweise finden. Was ich bereits weiß, ist Folgendes:

  1. Speichern Sie die Variablen des Modells mit a in einer Prüfpunktdatei (.ckpt) tf.train.Saver()und stellen Sie sie später wieder her ( Quelle ).
  2. Speichern Sie ein Modell in einer .pb-Datei und laden Sie es mit tf.train.write_graph()und tf.import_graph_def()( Quelle ) zurück.
  3. Laden Sie ein Modell aus einer .pb-Datei, trainieren Sie es erneut und speichern Sie es mit Bazel ( Quelle ) in einer neuen .pb-Datei.
  4. Frieren Sie das Diagramm ein, um das Diagramm und die Gewichte zusammen zu speichern ( Quelle )
  5. Verwenden Sie as_graph_def()diese Option , um das Modell zu speichern und für Gewichte / Variablen Konstanten zuzuordnen ( Quelle ).

Ich konnte jedoch einige Fragen zu diesen verschiedenen Methoden nicht klären:

  1. Speichern Checkpoint-Dateien nur die trainierten Gewichte eines Modells? Könnten Checkpoint-Dateien in ein neues Programm geladen und zum Ausführen des Modells verwendet werden, oder dienen sie einfach dazu, die Gewichte in einem Modell zu einem bestimmten Zeitpunkt / in einem bestimmten Stadium zu speichern?
  2. Werden tf.train.write_graph()auch die Gewichte / Variablen gespeichert?
  3. Kann Bazel nur zur Umschulung in .pb-Dateien gespeichert oder daraus geladen werden? Gibt es einen einfachen Bazel-Befehl, um ein Diagramm in eine .pb-Datei zu kopieren?
  4. Kann beim Einfrieren ein eingefrorenes Diagramm mit verwendet werden tf.import_graph_def()?
  5. Die Android-Demo für TensorFlow wird im Google Inception-Modell aus einer .pb-Datei geladen. Wenn ich meine eigene .pb-Datei ersetzen wollte, wie würde ich das tun? Muss ich nativen Code / native Methoden ändern?
  6. Was genau ist im Allgemeinen der Unterschied zwischen all diesen Methoden? Oder allgemeiner, was ist der Unterschied zwischen as_graph_def()/.ckpt/.pb?

Kurz gesagt, ich suche nach einer Methode, um sowohl ein Diagramm (wie in den verschiedenen Operationen und dergleichen) als auch seine Gewichte / Variablen in einer Datei zu speichern, die dann zum Laden des Diagramms und der Gewichte in ein anderes Programm verwendet werden kann , zur Verwendung (nicht unbedingt Fortsetzung / Umschulung).

Die Dokumentation zu diesem Thema ist nicht sehr einfach, daher wären Antworten / Informationen sehr willkommen.


2
Die neueste / vollständigste API ist das Metadiagramm, mit dem Sie alle drei gleichzeitig speichern können - 1) Diagramm 2) Parameterwerte 3) Sammlungen: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Jaroslaw Bulatow

Antworten:


80

Es gibt viele Möglichkeiten, das Problem des Speicherns eines Modells in TensorFlow anzugehen, was es etwas verwirrend machen kann. Nehmen Sie jede Ihrer Unterfragen der Reihe nach:

  1. Die Prüfpunktdateien (die beispielsweise durch Aufrufen saver.save()eines tf.train.SaverObjekts erstellt wurden) enthalten nur die Gewichte und alle anderen Variablen, die im selben Programm definiert sind. Um sie in einem anderen Programm zu verwenden, müssen Sie die zugehörige Diagrammstruktur neu erstellen (z. B. indem Sie Code ausführen, um sie erneut zu erstellen, oder aufrufen tf.import_graph_def()), um TensorFlow mitzuteilen, was mit diesen Gewichten zu tun ist. Beachten Sie, dass beim Aufrufen saver.save()auch eine Datei mit a erstellt wird MetaGraphDef, die ein Diagramm und Details zum Zuordnen der Gewichte von einem Prüfpunkt zu diesem Diagramm enthält. Weitere Informationen finden Sie im Tutorial .

  2. tf.train.write_graph()schreibt nur die Graphstruktur; nicht die Gewichte.

  3. Bazel hat nichts mit dem Lesen oder Schreiben von TensorFlow-Diagrammen zu tun. (Vielleicht verstehe ich Ihre Frage falsch: Sie können sie gerne in einem Kommentar klarstellen.)

  4. Ein eingefrorenes Diagramm kann mit geladen werden tf.import_graph_def(). In diesem Fall sind die Gewichte (normalerweise) in das Diagramm eingebettet, sodass Sie keinen separaten Prüfpunkt laden müssen.

  5. Die Hauptänderung besteht darin, die Namen der Tensoren, die in das Modell eingespeist werden, und die Namen der Tensoren, die aus dem Modell abgerufen werden, zu aktualisieren. In der TensorFlow-Android-Demo entspricht dies den inputNameund outputNameZeichenfolgen, an die übergeben wird TensorFlowClassifier.initializeTensorFlow().

  6. Dies GraphDefist die Programmstruktur, die sich normalerweise während des Trainingsprozesses nicht ändert. Der Checkpoint ist eine Momentaufnahme des Status eines Trainingsprozesses, der sich normalerweise bei jedem Schritt des Trainingsprozesses ändert. Infolgedessen verwendet TensorFlow unterschiedliche Speicherformate für diese Datentypen, und die Low-Level-API bietet verschiedene Möglichkeiten zum Speichern und Laden. Übergeordnete Bibliotheken wie MetaGraphDefBibliotheken, Keras und Skflow bauen auf diesen Mechanismen auf und bieten bequemere Möglichkeiten zum Speichern und Wiederherstellen eines gesamten Modells.


Bedeutet dies, dass die C ++ - API-Dokumentation liegt, wenn sie besagt, dass Sie das mit gespeicherte Diagramm laden tf.train.write_graph()und dann ausführen können?
Mnicky

2
Die C ++ API-Dokumentation lügt nicht, aber es fehlen einige Details. Das wichtigste Detail ist, dass Sie sich zusätzlich zu den GraphDefgespeicherten von tf.train.write_graph()auch die Namen der Tensoren merken müssen, die Sie beim Ausführen des Diagramms füttern und abrufen möchten (Punkt 5 oben).
mrry

@mrry: Ich habe versucht, das DeepDream-Beispiel für Tensorflows zu verwenden. aber es scheint, dass es vorgefertigte Modelle im pb-Format braucht! Ich habe das Cifar10-Beispiel ausgeführt, aber es werden nur Prüfpunkte erstellt! Ich konnte keine pb-Dateien finden oder was auch immer! Wie kann ich meine Checkpoints in das pb-Format konvertieren, das im Deepdream-Beispiel verwendet wird?
Rika

2
@ Coderx7 Ich denke wirklich, dass Sie eine .ckpt nicht in eine .pb konvertieren können, da der Prüfpunkt nur die Gewichte und Variablen enthält und nichts über die Struktur des Graphen
weiß

1
Gibt es einen einfachen Code, um eine .pb-Datei zu laden und dann auszuführen?
Kong

1

Sie können den folgenden Code ausprobieren:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.