2016-04-25 1 views
0

ich auf einem Textklassifikation Problem arbeite, eine Pipeline verwenden, die wie folgt aussieht:Wie kann ich tun, um eine Scikit-Learn Rastersuche mit Trainingsdaten, die ein Iterator ist

self.full_classifier = Pipeline([ 
     ('vectorize', CountVectorizer()), 
     ('tf-idf', TfidfTransformer()), 
     ('classifier', SVC(kernel='linear', class_weight='balanced')) 
    ]) 

Der vollständige Korpus zu groß ist um in den Speicher zu passen, aber klein genug, dass nach dem Vektorisierungsschritt keine Speicherprobleme auftreten. Ich kann erfolgreich einen Klassifikator passen durch

self.full_classifier.fit(
     self._all_data (max_samples=train_data_length), 
     self.dataset.head(train_data_length)['target'].values 
) 

mit dem self._all_data ein Iterator ist, der die Dokumente pro Trainingsbeispiel ergibt (während self.dataset nur Dokument-IDs und Ziele enthält). Hier ist max_samples optional, ich verwende es, um Trainings-/Testdaten zu teilen. Ich mag jetzt Gridsearch verwenden Parameter zu optimieren, für die ich diesen Code verwende:

parameters = { 
     'vectorize__stop_words': (None, 'english'), 
     'tfidf__use_idf': (True, False), 
     'classifier__class_weight': (None, 'balanced') 
    } 
gridsearch_classifier = GridSearchCV(self.full_classifier, parameters, n_jobs=-1) 
gridsearch_classifier.fit(self._all_data(), self.dataset['target'].values) 

Mein Problem ist, dass diese die folgenden Fehler erzeugt:

TypeError: Expected sequence or array-like, got <type 'generator'> 

mit dem Zurückverfolgungs zeigt auf dem gridsearch_classifier. Fit-Methode (und dann in Scikit-Code, Fehler in _num_samples (x). Da es möglich ist, mit einem Generator als Eingabe zu passen, habe ich mich gefragt, ob es auch eine Möglichkeit gibt, dies mit der Grid-Suche, die ich gerade bin vermisst. Jede Hilfe wird geschätzt!

Antwort

0

Nicht ohne den Generator als Liste zu materialisieren. Während verschiedene Fit-Methoden oft so strukturiert werden können, dass sie jeweils ein Element konsumieren und somit einen Iterator akzeptieren, führt die Grid-Suche zusätzlich eine Kreuzvalidierung durch und generiert CV-Splits der Daten durch Indizierung einer realisierten Menge.

+0

Danke, das macht Sinn. Ich werde eine Liste fälschen, indem ich ein __getitem__ implementiere, das die Datenbank trifft – Leo