2016-04-24 25 views
1

las ich versuchen Lenet zu trainieren hier definiert Solving in Python with LeNet die Ziffern-Erkennungsdaten auf Kaggle gesetzt zu trainieren. Ich benutze zuerst das Tutorial Create lmdb, um Daten in lmdb-Format zu übertragen. Dann folge ich den Anweisungen in Link 1 (Lösen in Python mit LeNet), um Trainings-, Test- und Löserprototypen zu konstruieren. Wenn ich Solver aus solver.prototxt extrahiere, habe ich jedoch festgestellt, dass jedes Element in den Bilddaten null ist. Ist mit meinem Code etwas nicht in Ordnung?Python mit Caffe: Die benutzerdefinierten Daten sind alle Nullen, wenn sie von Solver

import pandas as pd 
import lmdb 
import caffe 
import numpy as np 
import numpy as np 
from caffe import layers as L, params as P 
from pylab import * 
import os, sys 
from caffe.proto import caffe_pb2 
%matplotlib inline 

train_original = pd.read_csv(path/to/my/train.csv) 
test = pd.read_csv(path/to/my/test.csv) 
train_obs, dim = train_data.shape 
val_obs, dim = val_data.shape 
train_data_array = np.array(train_data, dtype = float32) 
train_label_array = np.array(train_label, dtype = float32) 
val_data_array = np.array(val_data, dtype = float32) 
val_label_array = np.array(val_label, dtype = float32) 

train_lmdb_size = train_data_array.nbytes * 10 
val_lmdb_size = val_data_array.nbytes * 10 
env = lmdb.open('train_lmdb', map_size=train_lmdb_size) 
with env.begin(write=True) as txn: 
    for i in range(train_num): 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.channels = 1 
     datum.height = 28 
     datum.width = 28 
     datum.data = train_data_array[i].reshape(28, 28).tobytes() # or .tostring() if numpy < 1.9 
     datum.label = int(train_label_array[i]) 
     str_id = '{:08}'.format(i) 
     # The encode is only essential in Python 3 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 

env = lmdb.open('test_lmdb', map_size=train_lmdb_size) 
with env.begin(write=True) as txn: 
    for i in range(val_num): 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.channels = 1 
     datum.height = 28 
     datum.width = 28 
     datum.data = val_data_array[i].reshape(28, 28).tobytes() # or .tostring() if numpy < 1.9 
     datum.label = int(val_label_array[i]) 
     str_id = '{:08}'.format(i) 
     # The encode is only essential in Python 3 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 

train_path = 'CNN_training.prototxt' 
test_path = 'CNN_testing.prototxt' 
train_lmdb_path = 'train_lmdb' 
test_lmdb_path = 'test_lmdb' 
solver_path = 'CNN_solver.prototxt' 

def lenet(lmdb, batch_size): 
    # our version of LeNet: a series of linear and simple nonlinear transformations 
    n = caffe.NetSpec() 

    n.data, n.label = L.Data(batch_size=batch_size, backend=P.Data.LMDB, source=lmdb, 
          transform_param=dict(scale=1./255), ntop=2) 

    n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=20, weight_filler=dict(type='xavier')) 
    n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX) 
    n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=50, weight_filler=dict(type='xavier')) 
    n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX) 
    n.fc1 = L.InnerProduct(n.pool2, num_output=500, weight_filler=dict(type='xavier')) 
    n.relu1 = L.ReLU(n.fc1, in_place=True) 
    n.score = L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier')) 
    n.loss = L.SoftmaxWithLoss(n.score, n.label) 

    return n.to_proto() 

with open(train_path, 'w') as f: 
    f.write(str(lenet(train_lmdb_path, 64))) 

with open(test_path, 'w') as f: 
    f.write(str(lenet(test_lmdb_path, 100))) 

s = caffe_pb2.SolverParameter() 
s.random_seed = 0xCAFFE 
s.train_net = train_path 
s.test_net.append(test_path) 
s.test_interval = 500 
s.test_iter.append(100) 
s.max_iter = 10000 
s.type = 'Adam' 
s.base_lr = 0.01 
s.momentum = 0.75 
s.weight_decay = 5e-1 
s.lr_policy = 'inv' 
s.gamma = 0.0001 
s.power = 0.75 
s.display = 1000 
s.snapshot = 5000 
s.snapshot_prefix = 'lin_lnet' 
s.solver_mode = caffe_pb2.SolverParameter.CPU 
with open(solver_path,'w') as f: 
    f.write(str(s)) 

solver = None 
solver = caffe.get_solver(solver_path) 
# result in solver.net['data'].data[0] are zeros 
print solver.net['data'].data[0] 
array([[[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.]]], dtype=float32) 

Antwort

1

Versuchen Sie eine net.forward(). Sie sollten Ihre Daten sehen können, wenn alles andere korrekt ist.

Eine einfachere und sicherere Methode zum Schreiben in die LMDB ist die Verwendung von caffe.io.array_to_datum als demonstriert here.

+0

danke, ich sehe es – user3162707

+0

@ user3162707 Bitte beachten Sie "akzeptieren" diese Antwort, indem Sie auf das "v" -Symbol daneben klicken. – Shai