Ich habe vor kurzem die AI-Klasse bei Coursera gestartet und ich habe eine Frage bezüglich meiner Implementierung des Gradientenabstiegsalgorithmus.Gradientenabstieg in Java
meine aktuelle Implementierung hier (ich gerade tatsächlich „übersetzt“ die mathematischen Ausdrücke in Java-Code):
public class GradientDescent {
private static final double TOLERANCE = 1E-11;
private double theta0;
private double theta1;
public double getTheta0() {
return theta0;
}
public double getTheta1() {
return theta1;
}
public GradientDescent(double theta0, double theta1) {
this.theta0 = theta0;
this.theta1 = theta1;
}
public double getHypothesisResult(double x){
return theta0 + theta1*x;
}
private double getResult(double[][] trainingData, boolean enableFactor){
double result = 0;
for (int i = 0; i < trainingData.length; i++) {
result = (getHypothesisResult(trainingData[i][0]) - trainingData[i][1]);
if (enableFactor) result = result*trainingData[i][0];
}
return result;
}
public void train(double learningRate, double[][] trainingData){
int iteration = 0;
double delta0, delta1;
do{
iteration++;
System.out.println("SUBS: " + (learningRate*((double) 1/trainingData.length))*getResult(trainingData, false));
double temp0 = theta0 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, false));
double temp1 = theta1 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, true));
delta0 = theta0-temp0; delta1 = theta1-temp1;
theta0 = temp0; theta1 = temp1;
}while((Math.abs(delta0) + Math.abs(delta1)) > TOLERANCE);
System.out.println(iteration);
}
}
Der Code funktioniert ganz gut, aber nur, wenn ich eine sehr kleine Alpha wählen, hier als lernkurs bezeichnet. Wenn es höher als 0,00001 ist, divergiert es.
Haben Sie Vorschläge zur Optimierung der Implementierung oder eine Erklärung für das "Alpha-Issue" und eine mögliche Lösung dafür?
Update:
Hier ist der Haupt darunter auch einige Beispieleingänge:
private static final double[][] TDATA = {{200, 20000},{300, 41000},{900, 141000},{800, 41000},{400, 51000},{500, 61500}};
public static void main(String[] args) {
GradientDescent gd = new GradientDescent(0,0);
gd.train(0.00001, TDATA);
System.out.println("THETA0: " + gd.getTheta0() + " - THETA1: " + gd.getTheta1());
System.out.println("PREDICTION: " + gd.getHypothesisResult(300));
}
Der mathematische Ausdruck von Gradientenabfallsaktualisierung sich wie folgt:
Sie sollten wahrscheinlich einige Beispiel Daten/Eingabe, in einer 'main' Methode, und vielleicht einen Backlink zu den Foren, die Sie" übersetzt "bereitstellen. – Marco13
Ich habe die Frage aktualisiert und ich habe auch ein kleines Problem gefunden. Nachdem ich es behoben habe, kann ich nun die Lernrate auf 0,0001 setzen. Aber ich denke, es ist immer noch ziemlich niedrig, aber viel besser als vorher. –
Welche der Coursera ML/AI-Klassen war es? Stanfords maschinelles Lernen? –