2016-07-15 15 views
1

Ich habe einen Datensatz mit einer Zielvariablen, die 7 verschiedene Beschriftungen haben kann. Jedes Sample in meinem Trainingssatz hat nur eine Bezeichnung für die Zielvariable.Sklearn - Wie die Wahrscheinlichkeit für alle Zielmarken vorhergesagt wird

Für jede Probe möchte ich die Wahrscheinlichkeit für jede der Zielbeschriftungen berechnen. Also würde meine Vorhersage aus 7 Wahrscheinlichkeiten für jede Reihe bestehen.

Auf der Sklearn-Website habe ich über Multi-Label-Klassifizierung gelesen, aber das scheint nicht, was ich will.

Ich habe den folgenden Code versucht, aber das gibt mir nur eine Klassifizierung pro Probe.

from sklearn.multiclass import OneVsRestClassifier 
clf = OneVsRestClassifier(DecisionTreeClassifier()) 
clf.fit(X_train, y_train) 
pred = clf.predict(X_test) 

Hat jemand dazu einen Rat? Vielen Dank!

Antwort

1

Sie können das tun, indem Sie einfach die OneVsRestClassifer entfernen und predict_proba Methode der DecisionTreeClassifier verwenden. Sie können Folgendes tun:

Dies gibt Ihnen eine Wahrscheinlichkeit für jede Ihrer 7 möglichen Klassen.

Hoffe, dass hilft!

2

Sie können versuchen, scikit-multilearn - eine Erweiterung von sklearn, die Multilabel-Klassifizierung behandelt. Wenn Sie Ihre Etiketten sind nicht übermäßig korreliert können Sie trainieren ein Klassifikator pro Etikett und alle Prognosen bekommen - versuchen (nach pip scikit-multilearn installieren):

from skmultilearn.problem_transform import BinaryRelevance  
classifier = BinaryRelevance(classifier = DecisionTreeClassifier()) 

# train 
classifier.fit(X_train, y_train) 

# predict 
predictions = classifier.predict(X_test) 

Prognosen eine spärliche Matrix der Größe enthalten (N_SAMPLES, n_labels) In Ihrem Fall - n_Labels = 7, enthält jede Spalte eine Vorhersage pro Label für alle Proben.

Falls Ihre Etiketten korreliert sind, benötigen Sie möglicherweise komplexere Methoden zur Klassifizierung mehrerer Labels.

Haftungsausschluss: Ich bin der Autor von scikit-multilearn, zögern Sie nicht, weitere Fragen zu stellen.

+0

In der kürzlich veröffentlichten Version 0.0.4 von scikit-multilearn finden Sie predict_proba-Implementierungen für problemtransformationsbasierte Multi-Label-Klassifikatormethoden. Im obigen Fall ersetzen Sie die letzte Zeile durch: predictions = classifier.predict_proba (X_test) – niedakh