Datatype incompatibility problem in Theano

I have a dense dropout ANN with a softmax output layer. Here is the training method:
import numpy as np
import theano
import theano.tensor as T
import lasagne

def train(self, network, input_var, epochs, train_input, val_input, batchsize,
          update_fn, loss_fn, verbose=True, deterministic=False, **kwargs):
    """
    :param network: the output layer of a `lasagne`-backed ANN
    :type input_var: TheanoVariable
    :param train_input: (x, y)
    :type train_input: (np.ndarray, np.ndarray)
    :param val_input: (x, y)
    :type val_input: (np.ndarray, np.ndarray)
    """
    # create target var
    # note: I use my own method instead of `theano.shared`, because for
    # whatever reason Theano says I can't use a shared variable here
    # and that I should pass it via the `givens` parameter, whatever
    # that is.
    target_var = self.numpy_to_theano_variable(train_input[1])
    # training functions
    prediction = lasagne.layers.get_output(network,
                                           deterministic=deterministic)
    loss = loss_fn(prediction, target_var).mean()
    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = update_fn(loss, params, **kwargs)
    train_fn = theano.function([input_var, target_var], loss, updates=updates)
    # validation functions (deterministic=True disables dropout)
    val_pred = lasagne.layers.get_output(network, deterministic=True)
    val_loss = loss_fn(val_pred, target_var).mean()
    val_acc = T.mean(T.eq(T.argmax(val_pred, axis=1), target_var),
                     dtype=theano.config.floatX)
    val_fn = theano.function([input_var, target_var], [val_loss, val_acc])

    def run_epoch(epoch):
        train_batches = yield_batches(train_input, batchsize)
        val_batches = yield_batches(val_input, batchsize)
        train_err = np.mean([train_fn(x, y) for x, y in train_batches])
        val_err, val_acc = np.mean(
            [val_fn(x, y) for x, y in val_batches], axis=0)
        if verbose:
            print("Epoch {} of {}: training error = {}, "
                  "validation error = {}, validation accuracy = {}"
                  "".format(epoch + 1, epochs, train_err, val_err, val_acc))
        return train_err, val_err, val_acc

    return [run_epoch(e) for e in xrange(epochs)]
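(yield_batches is a small helper of mine that just slices the (x, y) pair into minibatches; the exact implementation doesn't matter here, but it is roughly this sketch:)

def yield_batches(data, batchsize):
    # hypothetical sketch: walk both arrays in chunks of `batchsize`
    x, y = data
    for start in xrange(0, len(x), batchsize):
        yield x[start:start + batchsize], y[start:start + batchsize]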
The numpy_to_theano_variable method (together with the create_theano_variable helper it calls) is defined in the base class:
def create_theano_variable(ndim, dtype, name=None):
    """
    :type ndim: int
    :type dtype: str
    :type name: str
    """
    if ndim == 1:
        theano_var = T.vector(name, dtype=dtype)
    elif ndim == 2:
        theano_var = T.matrix(name, dtype=dtype)
    elif ndim == 3:
        theano_var = T.tensor3(name, dtype=dtype)
    elif ndim == 4:
        theano_var = T.tensor4(name, dtype=dtype)
    else:
        raise ValueError("ndim must be between 1 and 4, got {}".format(ndim))
    return theano_var

def numpy_to_theano_variable(array, name=None):
    """
    :type array: np.ndarray
    :param array: the array whose ndim and dtype the symbolic variable mirrors
    :rtype: T.TensorVariable
    """
    return create_theano_variable(ndim=array.ndim,
                                  dtype=str(array.dtype).split(".")[-1],
                                  name=name)
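Just to illustrate the dtype propagation: for a label vector of dtype int16, this yields a symbolic vector with the same ndim and dtype:

>>> import numpy as np
>>> y = np.zeros(10, dtype=np.int16)
>>> var = numpy_to_theano_variable(y, name="targets")
>>> var.ndim, var.dtype
(1, 'int16')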
At the beginning of train, target_var is thus initialized as a TheanoVariable with the same number of dimensions and the same dtype as the numpy array it is fed with. For reasons incomprehensible to me, whenever that dtype is not int32 or int64 I get this error:
Traceback (most recent call last):
  File "./train_net.py", line 131, in <module>
    main(sys.argv[1:])
  File "./train_net.py", line 123, in main
    learning_rate=learning_rate, momentum=momentum, verbose=True)
  File "/Users/ilia/OneDrive/GitHub/...", line 338, in train
    loss = loss_fn(prediction, target_var).mean()
  File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/lasagne/objectives.py", line 129, in categorical_crossentropy
    return theano.tensor.nnet.categorical_crossentropy(predictions, targets)
  File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/tensor/nnet/nnet.py", line 2077, in categorical_crossentropy
    return crossentropy_categorical_1hot(coding_dist, true_dist)
  File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/gof/op.py", line 613, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/tensor/nnet/nnet.py", line 1440, in make_node
    tensor.lvector))
TypeError: integer vector required for argument: true_one_of_n(got type: TensorType(<dtype>, vector) instead of: TensorType(int64, vector))
where <dtype> stands for the dtype of target_var as derived from the numpy array (I tested this with int8, int16, uint8, uint16, uint32 and uint64). Why does it accept only int32 and int64?
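The error is reproducible with plain Theano, without any of my code or Lasagne (a minimal sketch; per my tests above, any integer dtype other than int32/int64 triggers it):

import theano
import theano.tensor as T

pred = T.matrix("pred", dtype=theano.config.floatX)
targets = T.vector("targets", dtype="int16")
# raises the TypeError above already at graph-construction time
loss = T.nnet.categorical_crossentropy(pred, targets)

Casting the targets first (e.g. with target_var = T.cast(target_var, "int32")) would presumably get past the check, but I would like to understand why the op insists on exactly those two dtypes.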