2016-06-24 16 views
1

Ich habe Code in Tensorflow geschrieben, um den Bearbeitungsabstand zwischen einer Zeichenfolge und einer Gruppe von Zeichenfolgen zu berechnen. Ich kann den Fehler nicht herausfinden.Computing Edit Entfernung (feed_dict error)

import tensorflow as tf 
sess = tf.Session() 

# Create input data 
test_string = ['foo'] 
ref_strings = ['food', 'bar'] 

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return(tf.SparseTensor(indices, chars, [num_words,1,1])) 


test_string_sparse = create_sparse_vec(test_string*len(ref_strings)) 
ref_string_sparse = create_sparse_vec(ref_strings) 

sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True)) 

Dieser Code funktioniert und wenn ausführen, erzeugt es die Ausgabe:

array([[ 0.25], 
     [ 1. ]], dtype=float32) 

Aber wenn ich versuche, dies zu tun, indem sie die spärlichen Tensoren in durch spärliche Platzhalter Fütterung, erhalte ich einen Fehler. Hier

test_input = tf.sparse_placeholder(dtype=tf.string) 
ref_input = tf.sparse_placeholder(dtype=tf.string) 

edit_distances = tf.edit_distance(test_input, ref_input, normalize=True) 

feed_dict = {test_input: test_string_sparse, 
      ref_input: ref_string_sparse} 

sess.run(edit_distances, feed_dict=feed_dict) 

ist der Fehler Zurückverfolgungs:

Traceback (most recent call last): 

    File "<ipython-input-29-4e06de0b7af3>", line 1, in <module> 
    sess.run(edit_distances, feed_dict=feed_dict) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run 
run_metadata_ptr) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run 
    for subfeed, subfeed_val in _feed_fn(feed, feed_val): 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn 
    return feed_fn(feed, feed_val) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda> 
    [feed.indices, feed.values, feed.shape], feed_val)), 

TypeError: zip argument #2 must support iteration 

Jede Idee, was hier vor sich geht?

+0

der Fehler kommt wahrscheinlich aus dem Wert 'test_string_parse' oder' ref_string_parse', können Sie den Code für ihre Kreation bieten –

Antwort

2

TL; DR: Für den Rückgabetyp create_sparse_vec() verwenden tf.SparseTensorValue statt tf.SparseTensor.

Das Problem kommt hier aus dem Rückgabetyp create_sparse_vec(), die tf.SparseTensor ist, und ist nicht als Futter Wert im Aufruf von sess.run() verstanden.

Wenn Sie einen (dichten) Wert tf.Tensor eingeben, ist der Typ des erwarteten Werts ein NumPy-Array (oder bestimmte Objekte, die in ein Array konvertiert werden können). Wenn Sie einen tf.SparseTensor Feed einspeisen, ist der erwartete Wert ein , der einem tf.SparseTensor ähnelt, aber seine indices, values und shape Eigenschaften sind NumPy-Arrays (oder bestimmte Objekte, die in Arrays konvertiert werden können, wie die Listen in Ihrem Beispiel).

sollte der folgende Code arbeiten:

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return tf.SparseTensorValue(indices, chars, [num_words,1,1]) 
+0

Ihnen danken, die perfekt funktioniert?!. – nfmcclure