Unterschied zwischen Variable und get_variable in TensorFlow


125

Soweit ich weiß, Variableist dies die Standardoperation zum Erstellen einer Variablen und get_variablewird hauptsächlich zum Teilen von Gewichten verwendet.

Auf der einen Seite schlagen einige Leute vor, get_variableanstelle der primitiven VariableOperation zu verwenden, wann immer Sie eine Variable benötigen. Andererseits sehe ich lediglich eine Verwendung get_variablein den offiziellen Dokumenten und Demos von TensorFlow.

Daher möchte ich einige Faustregeln zur korrekten Verwendung dieser beiden Mechanismen kennen. Gibt es "Standard" -Prinzipien?


6
get_variable ist ein neuer Weg, Variable ist ein alter Weg (der für immer unterstützt werden könnte), wie Lukasz sagt (PS: Er hat einen Großteil des Variablennamens in TF geschrieben)
Yaroslav Bulatov

Antworten:


90

Ich würde empfehlen, immer zu verwenden tf.get_variable(...)- es erleichtert die Umgestaltung Ihres Codes, wenn Sie Variablen jederzeit gemeinsam nutzen müssen, z. B. in einer Multi-GPU-Einstellung (siehe das Multi-GPU-CIFAR-Beispiel). Es gibt keinen Nachteil.

Rein tf.Variableist niedriger; Irgendwann tf.get_variable()existierte es nicht mehr, so dass einige Codes immer noch den Low-Level-Weg verwenden.


5
Vielen Dank für Ihre Antwort. Aber ich habe noch eine Frage , wie zu ersetzen tf.Variablemit tf.get_variableüberall. Wenn ich dann eine Variable mit einem Numpy-Array initialisieren möchte, kann ich keine saubere und effiziente Methode finden, wie ich es tue tf.Variable. Wie lösen Sie es? Vielen Dank.
Lifu Huang

68

tf.Variable ist eine Klasse, und es gibt verschiedene Möglichkeiten, tf.Variable einschließlich tf.Variable.__init__und zu erstellen tf.get_variable.

tf.Variable.__init__: Erstellt eine neue Variable mit initial_value .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Ruft eine vorhandene Variable mit diesen Parametern ab oder erstellt eine neue. Sie können auch den Initialisierer verwenden.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Es ist sehr nützlich, Initialisierer wie die folgenden zu verwenden xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Weitere Informationen hier .


Ja, Variableeigentlich meine ich es zu benutzen __init__. Da dies get_variableso praktisch ist, frage ich mich, warum die meisten TensorFlow-Codes, die ich gesehen habe, Variablestattdessen verwendet wurden get_variable. Gibt es Konventionen oder Faktoren, die bei der Auswahl berücksichtigt werden müssen? Danke dir!
Lifu Huang

Wenn Sie einen bestimmten Wert haben möchten, ist die Verwendung von Variable einfach: x = tf.Variable (3).
Sung Kim

@SungKim Normalerweise können wir es bei Verwendung tf.Variable()als Zufallswert aus einer abgeschnittenen Normalverteilung initialisieren. Hier ist mein Beispiel w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). Was wäre das Äquivalent dazu? Wie kann ich sagen, dass ich eine abgeschnittene Normalität möchte? Soll ich es einfach tun w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: Mit können tf.truncated_normal_initializer()Sie das gewünschte Ergebnis erzielen .
Beta

46

Ich kann zwei Hauptunterschiede zwischen dem einen und dem anderen finden:

  1. Erstens tf.Variablewird immer eine neue Variable erstellt, während tf.get_variableeine vorhandene Variable mit angegebenen Parametern aus dem Diagramm abgerufen wird. Wenn diese nicht vorhanden ist, wird eine neue Variable erstellt.

  2. tf.Variable erfordert die Angabe eines Anfangswertes.

Es ist wichtig zu verdeutlichen, dass die Funktion tf.get_variabledem Namen den aktuellen Variablenbereich voranstellt, um Wiederverwendungsprüfungen durchzuführen. Beispielsweise:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

Der letzte Assertionsfehler ist interessant: Zwei Variablen mit demselben Namen im selben Bereich sollen dieselbe Variable sein. Wenn Sie jedoch die Namen von Variablen testen dund efeststellen, dass Tensorflow den Namen der Variablen geändert hat e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Tolles Beispiel! In Bezug auf d.nameund e.namebin ich gerade auf dieses TensorFlow-Dokument über die Benennung von TensordiagrammenIf the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
gestoßen

2

Ein weiterer Unterschied besteht darin, dass einer in der ('variable_store',)Sammlung ist, der andere jedoch nicht.

Bitte beachten Sie den Quellcode :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Lassen Sie mich das veranschaulichen:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Die Ausgabe:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.