2016-05-09 8 views
0

Ich benutze Keras Functional API, um ein einfaches Netzwerk mit mehreren Eingängen und mehreren Ausgängen zu realisieren. Aber ein Fehler ist zu mir gekommen und ich kann nicht herausfinden, wie ich ihn lösen soll. hier ist der Code:Keras Assertionsfehler

import numpy as np 
from keras.layers import Dense, Activation, Input, merge, Lambda 
from keras.models import Model 
from keras.optimizers import SGD 

def get_half_1(nparray): 
    return nparray[:,:5] 
def get_half_2(nparray): 
    return nparray[:,5:] 

train_x = np.random.uniform(0.0,1.0,size=(50,12)) 
train_y = np.random.uniform(0.0,1.0,(50,8)) 

x_row, x_col = train_x.shape 
y_row, y_col = train_y.shape 

x_input = Input(shape=(x_row,), name='x_input') 
y_input = Input(shape=(y_row,), name='y_input') 

x_hidden = Dense(5,activation='sigmoid')(x_input) 
y_hidden = Dense(5,activation='sigmoid')(y_input) 

# merge two layers 
com_x = merge([x_hidden, y_hidden],mode='concat') 

feature_layer = Dense(10, activation='sigmoid')(com_x) 

# decoding 
com_x_transpose = Dense(10,activation='sigmoid')(feature_layer) 

x_hidden_transpose = Lambda(get_half_1,output_shape=(50,5)) (com_x_transpose) 
y_hidden_transpose = Lambda(get_half_2,output_shape=(50,5))(com_x_transpose) 

x_recon_error = Dense(12,activation='sigmoid')(x_hidden_transpose) 
y_recon_error = Dense(8,activation='sigmoid')(y_hidden_transpose) 
# 
model = Model(input=[x_input, y_input],output=[x_recon_error, y_recon_error]) 


model.compile(optimizer='rmsprop',loss='mean_square_error') 

model.fit(train_x, train_y,nb_epoch=50,batch_size=50) 

ich diesen Code mit python3 laufen, und ich erhalte die folgenden Fehler:

Traceback (most recent call last): 
    File "splittest.py", line 35, in <module> 
    x_recon_error = Dense(12,activation='sigmoid')(x_hidden_transpose) 
    File "/Users/lw/Library/Python/3.5/lib/python/site- packages/keras/engine/topology.py", line 458, in __call__ 
    self.build(input_shapes[0]) 
    File "/Users/lw/Library/Python/3.5/lib/python/site-packages/keras/layers/core.py", line 583, in build 
    assert len(input_shape) == 2 
AssertionError 
+0

ändern Es ist ein sehr einfaches Beispiel. Mache ich einige Fehler in Bezug auf die variable Dimension? –

Antwort

0

Einfach

x_input = Input(shape=(x_row,), name='x_input') 
y_input = Input(shape=(y_row,), name='y_input') 

zu

x_input = Input(shape=train_x.shape, name='x_input') 
y_input = Input(shape=train_y.shape, name='y_input')