7

Ich verwende eine Warteschlange, um meine Trainingsbeispiele mithilfe des unten stehenden Codes in mein Netzwerk zu leiten, und es funktioniert ordnungsgemäß.Testen eines Netzwerks während des Trainings im Tensorflow bei Verwendung einer Warteschlange

Allerdings würde ich in der Lage sein mag einige Testdaten zu füttern alle n Iterationen, aber ich weiß nicht wirklich, wie ich vorgehen sollte. Soll ich kurzzeitig die Warteschlange anhalten und die Testdaten manuell einspeisen? Soll ich eine andere Warteschlange nur zum Testen von Daten erstellen?

Edit: ist der richtige Weg, es zu tun eine separate Datei zu erstellen, sagt eval.py, dass kontinuierlich den letzten Checkpoint liest und wertet das Netzwerk? So machen sie das im CIFAR10-Beispiel.

batch = 128 # size of the batch 
x = tf.placeholder("float32", [None, n_steps, n_input]) 
y = tf.placeholder("float32", [None, n_classes]) 

queue = tf.RandomShuffleQueue(capacity=4*batch, 
         min_after_dequeue=3*batch, 
         dtypes=[tf.float32, tf.float32], 
         shapes=[[n_steps, n_input], [n_classes]]) 
enqueue_op = queue.enqueue_many([x, y]) 
X_batch, Y_batch = queue.dequeue_many(batch) 

sess = tf.Session() 

def load_and_enqueue(data): 
    while True: 
     X, Y = data.get_next_batch(batch) 
     sess.run(enqueue_op, feed_dict={x: X, y: Y}) 

train_thread = threading.Thread(target=load_and_enqueue, args=(data)) 
train_thread.daemon = True 
train_thread.start() 

for _ in xrange(max_iter): 
    sess.run(train_op) 
+0

Es gibt einige gute High-Level-Funktionen für diese, die kürzlich hinzugefügt wurden [Github-Repository] (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/ python/schlank/evaluation.py). Sie basieren auf laufenden Auswertungen mit einer separaten ausführbaren Datei, die die durch das Training erstellten Prüfpunktdateien liest. – user728291

+0

@ user728291, gibt es ein Beispiel, um es innerhalb desselben Skripts zu tun? Es scheint, dass andere Tools wie Caffe es einfach so machen. –

+0

Wie wäre es mit zwei Warteschlangen (oder einer Warteschlange und einem gefütterten Platzhalter) und "tf.where" zu verwenden, um zu entscheiden, welche dieser beiden Quellen verwendet wird, um das Netzwerk zu versorgen? –

Antwort

-1

Sie können ein eval_op in Ihren Codes hinzufügen und dann die Auswertung in jedem n (sagen wir n = 1000) Iterationen durchführen. Ein Beispiel hierfür ist wie folgt:

for niter in xrange(max_iter): 
    sess.run(train_op) 
    if niter % 1000 == 0: 
     sess.run(eval_op) 
1

Sie einen weiteren Test Queue und eine Kopie des Trainingsmodell bulid kann als Testmodell wie folgt aus:

trainX, trainY = Queue0(batchSize, ...)... 
testX, testY= Queue1(batchSize, ...)... 
modelTrain = inference(trainX, trainY, ...) 
# reuse variables 
modelTest = inference(testX, testY, ...) 
sess.run(train_op,loss_op,trainX,trainY) 
sess.run(test_op,testX,testY) 

Auf diese Weise können mehr Speicher verbrauchen, da zwei Modelle initialisiert, hoffe, bessere Lösung zu sehen