2016-07-18 16 views
2

Ich versuche, ein Modell mit einer .ckpt Datei wiederherzustellen, die ich durch word2vec_optimized.py in tensorflow/models/embedding ausgeführt habe. Ich bin mir nicht sicher, wie ich die Variablen wiederherstellen soll, damit ich das Modell laden und verwenden kann, weil alle tf-Variablen in Klassen in tensorflow/models/embedding/word2vec_optimized.py gekapselt und initialisiert sind. Jede Hilfe wäre willkommen.Wie kann ein Tensorflow-Modell wiederhergestellt werden?

Auch wenn ich ".ckpt erstellt" wiederherstellen, habe ich jetzt eine Wor2Vec Instanz oder was bekomme ich eigentlich, wenn ich ein Modell mit einem .ckpt wiederherstellen?

Antwort

1

Wenn Sie die Speicherfunktion auf Ihrem Speicher aufrufen, übergeben Sie ihm die tf.Session, mit der Sie das Modell trainiert haben. Dies enthält einen Verweis auf das Diagramm, das alle Variablen enthält. Verwechseln Sie Python-Variablen nicht mit Tensorflow-Variablen. Auch wenn Sie in Python keine Variable mehr haben, die auf eine von Ihnen erstellte Tensorflussvariable zeigt, ist sie immer noch vorhanden, wenn sie Teil des Berechnungsgraphen ist. Versuchen Sie nach dem Erstellen des Modells den folgenden Code auszuführen.

for v in tf.all_variables(): 
    print(v.name) 

Dadurch wird der Name jeder von Ihnen erstellten Variablen ausgegeben. Der Sparer wird standardmäßig alle diese speichern. Solange die Variablen denselben Namen haben, wenn Sie sie wiederherstellen, spielt es keine Rolle, wo sie erstellt wurden. Stellen Sie nur sicher, dass Sie die Wiederherstellung durchführen, nachdem alle Variablen zum Modell hinzugefügt wurden. Wenn Sie einer Variablen einen Initialisierer geben, wird die Initialisierung nur ausgeführt, wenn Sie sess.run(tf.initialize_all_variables()) aufrufen. Sie müssen dies nicht aufrufen, wenn Sie nur die Werte wiederherstellen. Ich benutze oft den folgenden Code.

sess = tf.Session() 
saver = tf.train.Saver() 
if 'restore' in sys.argv: 
    saver.restore(sess, '/media/chase/98d61322-9ea7-473e-b835-8739c77d1e1e/model.chk') 
else: 
    sess.run(tf.initialize_all_variables()) 

Dieser Code funktioniert gut, wenn ich die thensorflow RNN Klassen verwenden, welche Variablen in sie erstellen.