-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMain.java
More file actions
84 lines (62 loc) · 2.69 KB
/
Main.java
File metadata and controls
84 lines (62 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import java.io.IOException;
import java.util.Random;
import java.util.Scanner;
/**
 * Trains a small {@code NeuralNetwork} on the XOR truth table and prints its progress.
 *
 * <p>Each target is a two-element vector: {@code {1, 0}} when exactly one input is 1 and
 * {@code {0, 1}} otherwise.
 */
public class Main {
    /** Number of training epochs to run. */
    private static final int EPOCHS = 1000;

    /** Rows sampled (with replacement) from the truth table to build each epoch's training set. */
    private static final int SAMPLES_PER_EPOCH = 100;

    /** The per-case cost is printed every {@code REPORT_INTERVAL} epochs (starting at epoch 0). */
    private static final int REPORT_INTERVAL = 100;

    /**
     * Third argument passed to {@code NeuralNetwork.epoch}; presumably the mini-batch size —
     * NOTE(review): confirm against NeuralNetwork.
     */
    private static final int BATCH_SIZE = 64;

    /**
     * Builds the network, trains it on randomly sampled XOR cases, then prints the network's
     * output and cost for every case.
     *
     * @param args unused
     * @throws IOException kept from the original signature; presumably needed by the
     *     NeuralNetwork save/load path — NOTE(review): confirm whether it is still required
     */
    public static void main(String[] args) throws IOException {
        // Layer sizes {2, 4, 4, 2}; 1.5 is presumably the learning rate — confirm in NeuralNetwork.
        NeuralNetwork nn = new NeuralNetwork(new int[] {2, 4, 4, 2}, 1.5);
        Random rand = new Random();
        // NOTE(review): the model name below suggests a 2-4-3-1 topology, which no longer matches
        // the layer sizes above — confirm before re-enabling.
        // NeuralNetwork nn = NeuralNetwork.load("NN_XOR_2_4_3_1");

        // Single source of truth for the XOR cases; both the training sampler and the
        // reports below read from these arrays, so they cannot drift apart.
        double[][] inputs = {
            {0, 0},
            {1, 0},
            {0, 1},
            {1, 1},
        };
        double[][] targets = {
            {0, 1},
            {1, 0},
            {1, 0},
            {0, 1},
        };

        // Training buffers are allocated once and refilled in place every epoch.
        double[][] inputsArr = new double[SAMPLES_PER_EPOCH][2];
        double[][] targetsArr = new double[SAMPLES_PER_EPOCH][2];
        for (int i = 0; i < EPOCHS; i++) {
            fillRandomBatch(rand, inputs, targets, inputsArr, targetsArr);
            if (i % REPORT_INTERVAL == 0) {
                System.out.println("i: " + i);
                printCosts(nn, inputs, targets);
            }
            nn.epoch(inputsArr, targetsArr, BATCH_SIZE);
        }

        // The cases are 0/1, so the int cast prints them as "0"/"1" rather than "0.0"/"1.0".
        for (double[] input : inputs) {
            System.out.println("Input: " + (int) input[0] + ", " + (int) input[1]
                    + "\nOutput: " + nn.guess(input));
        }
        printCosts(nn, inputs, targets);
        // System.out.println("Do you want to save NN?(y/n): ");
        //
        // Scanner scan = new Scanner(System.in);
        // if (scan.next().equals("y")) {
        //     System.out.println("Saving...");
        //     nn.save("NN_XOR_2_4_3_1");
        // }
    }

    /**
     * Overwrites every row of the batch buffers with a uniformly random case from the dataset.
     *
     * @param rand source of randomness; one {@code nextInt} call is made per batch row
     * @param inputs dataset inputs, one row per case
     * @param targets dataset targets, parallel to {@code inputs}
     * @param batchInputs destination rows for sampled inputs (modified in place)
     * @param batchTargets destination rows for sampled targets (modified in place)
     */
    private static void fillRandomBatch(Random rand, double[][] inputs, double[][] targets,
            double[][] batchInputs, double[][] batchTargets) {
        for (int j = 0; j < batchInputs.length; j++) {
            int index = rand.nextInt(inputs.length);
            System.arraycopy(inputs[index], 0, batchInputs[j], 0, batchInputs[j].length);
            System.arraycopy(targets[index], 0, batchTargets[j], 0, batchTargets[j].length);
        }
    }

    /**
     * Prints {@code nn.cost} for every case in the dataset, one line per case, in dataset order.
     *
     * @param nn network being evaluated
     * @param inputs dataset inputs, one row per case
     * @param targets dataset targets, parallel to {@code inputs}
     */
    private static void printCosts(NeuralNetwork nn, double[][] inputs, double[][] targets) {
        for (int k = 0; k < inputs.length; k++) {
            System.out.println(nn.cost(inputs[k], targets[k]));
        }
    }
}