2012-03-30 4 views
3

Ich möchte Parameter C und Gamma für C-SVM-Klassifizierung mit der RBF (radiale Basisfunktion) Kernel mit libsvm \ tools \ grid.py auswählen, aber ich weiß nicht, wie es möglich ist? Ich installierte libsvm und gnuplot und Python und lief die grid.py in Python, aber es hatte einen Fehler und zeigte die Ergebnisse nicht.Wie kann ich grid.py für die Parameterauswahl verwenden?

Antwort

12
%grid of parameters 
folds = 5; 
[C,gamma] = meshgrid(-5:2:15, -15:2:3); 
%# grid search, and cross-validation 
cv_acc = zeros(numel(C),1); 
d= 2; 
for i=1:numel(C) 
    cv_acc(i) = svmtrain(TrainLabel,TrainVec, ...   
     sprintf('-c %f -g %f -v %d -t %d', 2^C(i), 2^gamma(i), folds,d)); 
end 
%# pair (C,gamma) with best accuracy 
[~,idx] = max(cv_acc); 
%# contour plot of paramter selection 
contour(C, gamma, reshape(cv_acc,size(C))), colorbar 
hold on; 
text(C(idx), gamma(idx), sprintf('Acc = %.2f %%',cv_acc(idx)), ... 
    'HorizontalAlign','left', 'VerticalAlign','top') 
hold off 
xlabel('log_2(C)'), ylabel('log_2(\gamma)'), title('Cross-Validation Accuracy') 
%# now you can train you model using best_C and best_gamma 
best_C = 2^C(idx); best_gamma = 2^gamma(idx); %# ... 

Dies führt Rastersuche als auch ... aber Matlab ... nicht grid.py mit ... vielleicht hilft das ...

+0

Sorry, ich verstehe nicht, den Rückgabewert der Funktion svmtrain. Es sollte ein Modell sein, das zurückgegeben wird. Warum behandeln Sie es als eine Genauigkeit und welche Genauigkeit ist es? Vielen Dank! –

+0

@lakesh: Hier, für dieses Gitter Suche und die Auswahl der Parameter, muss es alle Daten (Zug und Test) als Eingabe in den "Svmtrain" Feed oder gibt es eine andere Möglichkeit, mit Daten umzugehen? – Amin

6

Sie die Matlab-Skript statt Gitter versehen verwenden könnte .py FAQ

Frage: Wie kann ich die MATLAB-Schnittstelle für die Parameterauswahl verwenden? http://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#f803

bestcv = 0; 
for log2c = -1:3, 
    for log2g = -4:1, 
    cmd = ['-v 5 -c ', num2str(2^log2c), ' -g ', num2str(2^log2g)]; 
    cv = svmtrain(heart_scale_label, heart_scale_inst, cmd); 
    if (cv >= bestcv), 
     bestcv = cv; bestc = 2^log2c; bestg = 2^log2g; 
    end 
    fprintf('%g %g %g (best c=%g, g=%g, rate=%g)\n', log2c, log2g, cv, bestc, bestg, bestcv); 
    end 
end