2015-11-23 5 views
22

Ich versuche, meine eigene RNNCell (Echo State Network) in Tensorflow zu definieren, nach unten Definition.Wie kann ich eine benutzerdefinierte RNN (speziell eine ESN) in Tensorflow implementieren?

x (t + 1) = tanh (Win * u (t) + W * x (t) + Wfb * y (t))

y (t) = Wout * z (t)

z (t) = [x (t), u (t)] ist

x Zustand, u-Eingang ist, y ausgegeben. Win, W und Wfb sind nicht trainierbar. Alle Gewichte werden zufällig initialisiert, aber W wird wie folgt geändert:. „Einen bestimmten Prozentsatz von Elementen von W 0 ist, Skala W unter 1,0

seine spektralen Radius zu halten ich diesen Code haben, um die Gleichung zu erzeugen

x = tf.Variable(tf.reshape(tf.zeros([N]), [-1, N]), trainable=False, name="state_vector") 
W = tf.Variable(tf.random_normal([N, N], 0.0, 0.05), trainable=False) 
# TODO: setup W according to the ESN paper 
W_x = tf.matmul(x, W) 

u = tf.placeholder("float", [None, K], name="input_vector") 
W_in = tf.Variable(tf.random_normal([K, N], 0.0, 0.05), trainable=False) 
W_in_u = tf.matmul(u, W_in) 

z = tf.concat(1, [x, u]) 
W_out = tf.Variable(tf.random_normal([K + N, L], 0.0, 0.05)) 
y = tf.matmul(z, W_out) 
W_fb = tf.Variable(tf.random_normal([L, N], 0.0, 0.05), trainable=False) 
W_fb_y = tf.matmul(y, W_fb) 

x_next = tf.tanh(W_in_u + W_x + W_fb_y) 

y_ = tf.placeholder("float", [None, L], name="train_output") 

Mein Problem ist zweifach. Erstens weiß ich nicht, wie das von RNNCell als Superklasse zu implementieren. Zweitens weiß ich nicht, wie man eine W-Tensor nach oben Spezifikation zu erzeugen.

Jede Hilfe über jede dieser Fragen wird sehr geschätzt, vielleicht kann ich einen Weg finden, W vorzubereiten, aber ich verstehe ganz bestimmt nicht, wie ich meine eigenen umsetzen soll RNN als eine Oberklasse von RNNCell.

Antwort

10

Um eine kurze Zusammenfassung zu geben:

Blick in den TensorFlow Quellcode unter python/ops/rnn_cell.py zu sehen, wie man Unterklasse RNNCell. Es ist normalerweise so:

class MyRNNCell(RNNCell): 
    def __init__(...): 

    @property 
    def output_size(self): 
    ... 

    @property 
    def state_size(self): 
    ... 

    def __call__(self, input_, state, name=None): 
    ... your per-step iteration here ...