-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLinearRegression.java
More file actions
105 lines (84 loc) · 3.01 KB
/
Copy pathLinearRegression.java
File metadata and controls
105 lines (84 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Scanner;
/**
 * Single-feature linear regression ({@code y = w * x + b}) fitted by batch gradient descent on the
 * mean squared error.
 *
 * <p>{@link #main} reads whitespace-separated (x, y) pairs from a text file, trains the model,
 * and prints the fitted parameters plus one sample prediction.
 */
public class LinearRegression {
    /** Data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "pizza.txt";

    /**
     * Trains a model on the data file and prints the fitted parameters and a prediction for x = 20.
     *
     * @param args optional; {@code args[0]} overrides the data file path (default {@code pizza.txt})
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : DEFAULT_DATA_FILE;
        try {
            double[][] data = readPizzaData(path);
            double[] X = data[0];
            double[] Y = data[1];
            LinearRegression model = new LinearRegression();
            int iterations = 20000;
            double lr = 0.001;
            double[] params = model.train(X, Y, iterations, lr);
            System.out.printf("%nw = %.10f, b = %.10f%n", params[0], params[1]);
            System.out.printf("Prediction: x = %d => y = %.2f%n", 20, model.predict(20, params[0], params[1]));
        } catch (FileNotFoundException e) {
            System.err.println("File not found: " + e.getMessage());
        } catch (IllegalArgumentException e) {
            // Raised by train() when the file yielded no usable (x, y) pairs.
            System.err.println("Invalid data in " + path + ": " + e.getMessage());
        }
    }

    /**
     * Reads (x, y) pairs from a whitespace-separated text file.
     *
     * <p>The first line is treated as a header and skipped. After that, non-numeric tokens are
     * ignored and numeric tokens are consumed in pairs. A number that is not followed by another
     * number (end of input or a non-numeric token) ends parsing. Numbers are parsed with
     * {@link Locale#ROOT}, so {@code .} is the decimal separator regardless of the JVM locale.
     *
     * @param file path to the data file
     * @return a two-element array {@code {X, Y}} whose arrays have equal length
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static double[][] readPizzaData(String file) throws FileNotFoundException {
        List<Double> reservations = new ArrayList<>();
        List<Double> pizzas = new ArrayList<>();
        try (Scanner sc = new Scanner(new File(file))) {
            // Without a fixed locale, e.g. de_DE expects "," and would silently skip "2.5".
            sc.useLocale(Locale.ROOT);
            if (sc.hasNextLine()) {
                sc.nextLine(); // header line
            }
            while (sc.hasNext()) {
                if (!sc.hasNextDouble()) {
                    sc.next(); // skip non-numeric token
                    continue;
                }
                double r = sc.nextDouble();
                if (!sc.hasNextDouble()) {
                    break; // unpaired value: stop reading
                }
                double p = sc.nextDouble();
                reservations.add(r);
                pizzas.add(p);
            }
        }
        int n = reservations.size();
        double[] X = new double[n];
        double[] Y = new double[n];
        for (int i = 0; i < n; i++) {
            X[i] = reservations.get(i);
            Y[i] = pizzas.get(i);
        }
        return new double[][] {X, Y};
    }

    /**
     * Evaluates the linear model at one point.
     *
     * @param x input value
     * @param w slope
     * @param b intercept
     * @return {@code x * w + b}
     */
    public double predict(double x, double w, double b) {
        return x * w + b;
    }

    /**
     * Computes the mean squared error of the model over the data set.
     *
     * @param X inputs
     * @param Y targets, same length as {@code X}
     * @param w slope
     * @param b intercept
     * @return mean of {@code (predict(x) - y)^2}
     * @throws IllegalArgumentException if {@code X} and {@code Y} differ in length or are empty
     */
    public double loss(double[] X, double[] Y, double w, double b) {
        requireSameNonEmptyLength(X, Y);
        double sum = 0.0;
        for (int i = 0; i < X.length; i++) {
            double err = predict(X[i], w, b) - Y[i];
            sum += err * err;
        }
        return sum / X.length;
    }

    /**
     * Computes the gradient of {@link #loss} with respect to {@code w} and {@code b}.
     *
     * @param X inputs
     * @param Y targets, same length as {@code X}
     * @param w slope
     * @param b intercept
     * @return {@code {dLoss/dw, dLoss/db}}
     * @throws IllegalArgumentException if {@code X} and {@code Y} differ in length or are empty
     */
    public double[] gradient(double[] X, double[] Y, double w, double b) {
        requireSameNonEmptyLength(X, Y);
        double sumXErr = 0.0;
        double sumErr = 0.0;
        int n = X.length;
        for (int i = 0; i < n; i++) {
            double err = predict(X[i], w, b) - Y[i];
            sumXErr += X[i] * err;
            sumErr += err;
        }
        double wGrad = 2.0 * (sumXErr / n);
        double bGrad = 2.0 * (sumErr / n);
        return new double[] {wGrad, bGrad};
    }

    /**
     * Fits {@code w} and {@code b} by gradient descent starting from zero, printing the loss on
     * every iteration.
     *
     * @param X inputs
     * @param Y targets, same length as {@code X}
     * @param iterations number of gradient steps
     * @param lr learning rate
     * @return {@code {w, b}}
     * @throws IllegalArgumentException if {@code X} and {@code Y} differ in length or are empty
     */
    public double[] train(double[] X, double[] Y, int iterations, double lr) {
        return train(X, Y, iterations, lr, 1);
    }

    /**
     * Fits {@code w} and {@code b} by gradient descent starting from zero.
     *
     * @param X inputs
     * @param Y targets, same length as {@code X}
     * @param iterations number of gradient steps
     * @param lr learning rate
     * @param logEvery print the loss every {@code logEvery} iterations; {@code <= 0} disables logging
     * @return {@code {w, b}}
     * @throws IllegalArgumentException if {@code X} and {@code Y} differ in length or are empty
     */
    public double[] train(double[] X, double[] Y, int iterations, double lr, int logEvery) {
        requireSameNonEmptyLength(X, Y);
        double w = 0.0;
        double b = 0.0;
        for (int i = 0; i < iterations; i++) {
            // Loss is only needed for logging, so skip the extra pass when not printing.
            if (logEvery > 0 && i % logEvery == 0) {
                System.out.printf("Iteration %4d => Loss: %.10f%n", i, loss(X, Y, w, b));
            }
            double[] grads = gradient(X, Y, w, b);
            w -= lr * grads[0];
            b -= lr * grads[1];
        }
        return new double[] {w, b};
    }

    /** Rejects mismatched or empty data, which would otherwise index out of bounds or divide 0/0. */
    private static void requireSameNonEmptyLength(double[] X, double[] Y) {
        if (X.length != Y.length) {
            throw new IllegalArgumentException(
                    "X and Y must have the same length: " + X.length + " != " + Y.length);
        }
        if (X.length == 0) {
            throw new IllegalArgumentException("X and Y must not be empty");
        }
    }
}