diff --git a/Driver.java b/Driver.java deleted file mode 100644 index e58022c..0000000 --- a/Driver.java +++ /dev/null @@ -1,22 +0,0 @@ -import java.util.List; - -public class Driver { - - static double[][] X = { { 0, 0 }, { 1, 0 }, { 0, 1 }, { 1, 1 } }; - static double[][] Y = { { 0 }, { 1 }, { 1 }, { 0 } }; - - public static void main(String[] args) { - - NeuralNetwork nn = new NeuralNetwork(2, 10, 1, 0.01, true); - - List output; - nn.fit(X, Y, 500, 0); - double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } }; - for (double d[] : input) { - output = nn.predict(d); - System.out.println(output.toString()); - } - - } - -} \ No newline at end of file diff --git a/Neural_Network/Main.java b/Neural_Network/Main.java new file mode 100644 index 0000000..760fdc7 --- /dev/null +++ b/Neural_Network/Main.java @@ -0,0 +1,16 @@ +class Main { + public static void main(String [] args) { + NeuralNetwork nn = new NeuralNetwork(2, 10, 1); + + double[][] training_data_in = nn.getFileData("training_data_in.txt"); + double[][] training_data_out = nn.getFileData("training_data_out.txt"); + + double[][] testing_data_in = nn.getFileData("testing_data_in.txt"); + double[][] testing_data_out = nn.getFileData("testing_data_out.txt"); + + int epochs = 50000; + nn.fit(training_data_in,training_data_out,epochs,0); + + nn.testData(training_data_in,training_data_out); + } +} diff --git a/NeuralNetwork.java b/Neural_Network/NeuralNetwork.java similarity index 74% rename from NeuralNetwork.java rename to Neural_Network/NeuralNetwork.java index f52704a..a7f3ce0 100644 --- a/NeuralNetwork.java +++ b/Neural_Network/NeuralNetwork.java @@ -1,9 +1,14 @@ - +import java.io.FileNotFoundException; import java.util.ArrayList; +import java.io.PrintWriter; +import java.util.Scanner; +import java.util.Arrays; import java.util.List; +import java.io.File; public class NeuralNetwork { - + private static BasicUtils U = new BasicUtils(); + Matrix weights_ih, weights_ho, bias_h, bias_o; double l_rate = 0.01; boolean useMultiThreading = false; @@ -14,7 +19,6 @@ public NeuralNetwork(int i, int h, int o) { bias_h = new Matrix(h, 1); bias_o = new Matrix(o, 1); - } public NeuralNetwork(int i, int h, int o, boolean useMultiThreading) { @@ -25,7 +29,6 @@ public NeuralNetwork(int i, int h, int o, boolean useMultiThreading) { bias_o = new Matrix(o, 1); this.useMultiThreading = useMultiThreading; - } public NeuralNetwork(int i, int h, int o, double l_rate) { @@ -36,7 +39,6 @@ public NeuralNetwork(int i, int h, int o, double l_rate) { bias_o = new Matrix(o, 1); this.l_rate = l_rate; - } public NeuralNetwork(int i, int h, int o, double l_rate, boolean useMultiThreading) { @@ -48,7 +50,6 @@ public NeuralNetwork(int i, int h, int o, double l_rate, boolean useMultiThreadi this.l_rate = l_rate; this.useMultiThreading = useMultiThreading; - } public List predict(double[] X) { @@ -75,7 +76,7 @@ public void fit(double[][] X, double[][] Y, int epochs, int verbose) { switch (verbose) { case 0: { - System.out.println("Staring training with " + epochs + " epochs"); + U.println("Staring training with " + epochs + " epochs"); long start = System.currentTimeMillis(); for (int i = 0; i < epochs; i++) { int sampleN = (int) (Math.random() * X.length); @@ -83,27 +84,26 @@ public void fit(double[][] X, double[][] Y, int epochs, int verbose) { } long end = System.currentTimeMillis(); long elapsedTime = end - start; - System.out.println("Training took : " + (elapsedTime / 1000) + "s"); + U.println("Training took : " + (elapsedTime / 1000) + "s\n"); break; } case 1: { - System.out.println("Staring training with " + epochs + " epochs"); + U.println("Staring training with " + epochs + " epochs"); long start = System.currentTimeMillis(); for (int i = 0; i < epochs; i++) { - System.out.println("Epoch: " + (i + 1)); + U.println("Epoch: " + (i + 1)); int sampleN = (int) (Math.random() * X.length); this.train(X[sampleN], Y[sampleN], true); } long end = System.currentTimeMillis(); long elapsedTime = end - start; - System.out.println("Training took : " + (elapsedTime / 1000) + "s"); + U.println("Training took : " + (elapsedTime / 1000) + "s"); break; } } - } public void train(double[] X, double[] Y, Boolean showLoss) { @@ -144,7 +144,6 @@ public void train(double[] X, double[] Y, Boolean showLoss) { weights_ih.add(wih_delta); bias_h.add(h_gradient); - } private void printLoss(Matrix error) { @@ -156,12 +155,54 @@ private void printLoss(Matrix error) { } } - System.out.print("Average Error: " + avg + "\n"); + U.print("Average Error: " + avg + "\n"); + } + + public void testData(double[][] testing_in, double[][] testing_out) { + int i = 0; + List output; + for (double d[] : testing_in) { + output = predict(d); + U.printf("Q: %s\nIn: %s\nOut: %s\nPrediction: %s\n",i,U.arrStr(d),U.arrStr(testing_out[i++]),output.toString()); + } + } + + public double[][] getFileData(String fileName) { + List rawData = U.readFile(fileName); + List filteredData = new ArrayList(); + + for (String n : rawData) { + String[] s = n.split(","); + List temp = new ArrayList(); + for (String i : s) { + temp.add(Double.parseDouble(i)); + } + filteredData.add((ArrayList) temp); + } + + double[][] dataArr = nestedListToNestedArr(filteredData); + + return dataArr; + } + + private static double[][] nestedListToNestedArr(List dataList) { + double[][] dataArr = new double[dataList.size()][dataList.get(0).size()]; + int y = 0; + for (List n : dataList) { + int x = 0; + for (Object i : n) { + dataArr[y][x] = ((Double) i); + x++; + } + y++; + } + return dataArr; } - } class Matrix { + private static BasicUtils U = new BasicUtils(); + double[][] data; int rows, cols; @@ -179,9 +220,9 @@ public Matrix(int rows, int cols) { public void print() { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - System.out.print(this.data[i][j] + " "); + U.print(this.data[i][j] + " "); } - System.out.println(); + U.println(""); } } @@ -196,7 +237,7 @@ public void add(int scaler) { public void add(Matrix m) { if (cols != m.cols || rows != m.rows) { - System.out.println("Shape Mismatch"); + U.println("Shape Mismatch"); return; } @@ -212,7 +253,6 @@ public static Matrix fromArray(double[] x) { for (int i = 0; i < x.length; i++) temp.data[i][0] = x[i]; return temp; - } public List toArray() { @@ -285,7 +325,6 @@ public void multiply(double a) { this.data[i][j] *= a; } } - } public void sigmoid() { @@ -293,7 +332,6 @@ public void sigmoid() { for (int j = 0; j < cols; j++) this.data[i][j] = 1 / (1 + Math.exp(-this.data[i][j])); } - } public Matrix dsigmoid() { @@ -303,7 +341,6 @@ public Matrix dsigmoid() { temp.data[i][j] = this.data[i][j] * (1 - this.data[i][j]); } return temp; - } } @@ -357,7 +394,6 @@ public void run() { } class RowMultiplyWorker implements Runnable { - private final Matrix result; private Matrix matrix1; private Matrix matrix2; @@ -377,11 +413,57 @@ public void run() { result.data[row][i] = 0; for (int j = 0; j < matrix1.data[row].length; j++) { result.data[row][i] += matrix1.data[row][j] * matrix2.data[j][i]; - } - } - } +} -} \ No newline at end of file +final class BasicUtils { + private static PrintWriter writer = new PrintWriter(System.out); + + public static List readFile(String fileName) { + List out = new ArrayList(); + try { + File file = new File(fileName); + Scanner reader = new Scanner(file); + while (reader.hasNextLine()) { + String data = reader.nextLine(); + out.add(data); + } + reader.close(); + } catch (FileNotFoundException e) { + println("An error occurred."); + e.printStackTrace(); + } + + return out; + } + + public static String arrStr(double[] arr) { + return Arrays.toString(arr); + } + + public static void print(Object str) { + writer.write(String.valueOf(str)); + writer.flush(); + } + + public static void println(Object str) { + print(str+"\n"); + } + + public static void printf(Object mainText, Object... x) { + String toPrint = ""; + + String[] y = String.valueOf(mainText).split("%s"); + int i = 0; + for (String n : y) { + toPrint += n; + if (i