2016-04-16 3 views
0

Ich arbeite an einem Spark-Projekt mit scala. Ich möchte ein Modell trainieren, das k_means, gaussian_mixture, logistische Regression, naive_bayes etc. sein kann. Aber ich kann ein generisches Modell nicht als Rückgabetyp definieren. Da die Typen dieser Algorithmen sich unterscheiden wie GaussianMixmentModel, KMeansModel usw. kann ich keinen logischen Weg finden, dieses trainierte Modell zurückzugeben. HierMehrere Typen für eine Variable in Spark mit Hilfe von Scala

ist ein Frieden von Code aus dem Projekt:

model.model_algorithm match { 

     case "k_means" => 

     val model_k_means = k_means(data, parameters) 

     case "gaussian_mixture" => 

     val model_gaussian_mixture = gaussian_mixture(data, parameters) 

     case "logistic_regression" => 

     val model_logistic_regression = logistic_regression(data, parameters) 

} 

So ist es eine Möglichkeit, dies trainierten Modell zurückzukehren oder ein generisches Modell zu definieren, die alle Arten akzeptiert?

+0

was ist es, dass Sie mit dem trainierten Modell _do_ wollen? Diese Klassen erweitern alle 'org.apache.spark.mllib.util.Saveable',' AntRef' und 'Any', sodass Ihre Methode einen dieser Typen zurückgeben kann, aber das wird Ihnen nicht unbedingt helfen. Wenn Sie später Aktion X für diese Ergebnisse ausführen möchten, möchten Sie möglicherweise eine Eigenschaft 'ModelResult' mit der Methode X erstellen, diese Mustererkennung mit 'ModelResult' zurückgeben und drei Implementierungen dieser Eigenschaft ausführen, von denen jede ein anderes Modell behandelt. –

+0

Ich habe versucht, sie vom Typ Any zu machen, aber vorherzusagen() -Methode kann in diesem Fall nicht verwendet werden. Können Sie mir bitte erklären, wie ich in diesem Fall Mustervergleiche implementieren kann. Vielen Dank für Ihre Antwort. –

+0

Sie haben also tatsächlich drei Modelle und Pattern-Matching initiiert, um zu wissen, welcher läuft. Wenn dies der Fall ist, ist es eine schlechte Übung. – eliasah

Antwort

1

Sie können eine gemeinsame Schnittstelle erstellen, um all Ihre interne Logik des Trainings und der Vorhersage zu umhüllen und einfach eine einfache Schnittstelle freizulegen, die wiederverwendet werden kann.

trait AlgorithmInterface extends Serializable { 
    def train(data: RDD[LabeledPoint]) 
    def predict(record: Vector) 
} 

und Algorithmen wie

in Klassen implementiert haben
class LogisticRegressionAlgorithm extends AlgorithmInterface { 
    var model:LogisticRegressionModel = null 
    override def train(data: RDD[LabeledPoint]): Unit = { 
    model = new LogisticRegressionWithLBFGS() 
     .setNumClasses(10) 
     .run(data) 
    } 
    override def predict(record:Vector): Double = model.predict(record) 
} 

class GaussianMixtureAlgorithm extends AlgorithmInterface { 
    var model: GaussianMixtureModel = null 
    override def train(data: RDD[LabeledPoint]): Unit = { 
    model = new GaussianMixture().setK(2).run(data.map(_.features)) 
    } 
    override def predict(record: Vector) = model.predict(record) 
} 

Implementierung es

// Assigning the models to an Array[AlgorithmInterface] 
    val models: Array[AlgorithmInterface] = Array(
     new LogisticRegressionAlgorithm(), 
     new GaussianMixtureAlgorithm() 
    ) 
    // Training the Models using the Interfaces Train Function 
    models.foreach(_.train(data)) 
    //Predicting the Value 
    models.foreach(model=> println(model.predict(vectorData)))