2016-05-31 27 views
3

Während ich Tensorflow verwende, versuche ich CIFAR10 Training mit einer Checkpointed-Datei fortzusetzen. Einige andere Artikel referenzierend, versuchte ich tf.train.Saver()., Wiederherstellen ohne Erfolg. Kann mir jemand etwas darüber sagen, wie es weitergehen soll?Tensorflow cifar10 Fortsetzen Training von Checkpoint-Datei

-Code-Schnipsel aus Tensorflow CIFAR10

def train(): 
    # methods to build graph from the cifar10_train.py 
    global_step = tf.Variable(0, trainable=False) 
    images, labels = cifar10.distorted_inputs() 
    logits = cifar10.inference(images) 
    loss = cifar10.loss(logits, labels) 
    train_op = cifar10.train(loss, global_step) 
    saver = tf.train.Saver(tf.all_variables()) 
    summary_op = tf.merge_all_summaries() 

    init = tf.initialize_all_variables() 
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) 
    sess.run(init) 


    print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir) 

    if FLAGS.checkpoint_dir is None: 
    # Start the queue runners. 
    tf.train.start_queue_runners(sess=sess) 
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) 
    else: 
    # restoring from the checkpoint file 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 

    # cur_step prints out well with the checkpointed variable value 
    cur_step = sess.run(global_step); 
    print("current step is %s" % cur_step) 

    for step in xrange(cur_step, FLAGS.max_steps): 
    start_time = time.time() 
    # **It stucks at this call ** 
    _, loss_value = sess.run([train_op, loss]) 
    # below same as original 

Antwort

2

Das Problem scheint, dass diese Linie zu sein:

tf.train.start_queue_runners(sess=sess) 

... nur wenn FLAGS.checkpoint_dir is None ausgeführt wird. Sie müssen die Warteschlange weiterhin starten, wenn Sie von einem Prüfpunkt wiederherstellen.

Bitte beachte, dass ich empfehlen würde man den Warteschlange Läufer nach Erstellen der tf.train.Saver (aufgrund einer Race-Bedingung in der freigegebenen Version des Codes) zu starten, so dass eine bessere Struktur wäre:

if FLAGS.checkpoint_dir is not None: 
    # restoring from the checkpoint file 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 

# Start the queue runners. 
tf.train.start_queue_runners(sess=sess) 

# ... 

for step in xrange(cur_step, FLAGS.max_steps): 
    start_time = time.time() 
    _, loss_value = sess.run([train_op, loss]) 
    # ... 
+0

dankt Du für die Antwort! Es hat das Problem gelöst. Ich dachte, queue_runner ist verantwortlich für die Erstellung des Eingangsbildes (durch Verzerrung) und es ist kein notwendiger Schritt, wie ich aus einer Prüfpunktdatei wiederherstellen. – emerson