Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions Driver.java

This file was deleted.

16 changes: 16 additions & 0 deletions Neural_Network/Main.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class Main {
public static void main(String [] args) {
NeuralNetwork nn = new NeuralNetwork(2, 10, 1);

double[][] training_data_in = nn.getFileData("training_data_in.txt");
double[][] training_data_out = nn.getFileData("training_data_out.txt");

double[][] testing_data_in = nn.getFileData("testing_data_in.txt");
double[][] testing_data_out = nn.getFileData("testing_data_out.txt");

int epochs = 50000;
nn.fit(training_data_in,training_data_out,epochs,0);

nn.testData(training_data_in,training_data_out);
}
}
136 changes: 109 additions & 27 deletions NeuralNetwork.java → Neural_Network/NeuralNetwork.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@

import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.io.PrintWriter;
import java.util.Scanner;
import java.util.Arrays;
import java.util.List;
import java.io.File;

public class NeuralNetwork {

private static BasicUtils U = new BasicUtils();

Matrix weights_ih, weights_ho, bias_h, bias_o;
double l_rate = 0.01;
boolean useMultiThreading = false;
Expand All @@ -14,7 +19,6 @@ public NeuralNetwork(int i, int h, int o) {

bias_h = new Matrix(h, 1);
bias_o = new Matrix(o, 1);

}

public NeuralNetwork(int i, int h, int o, boolean useMultiThreading) {
Expand All @@ -25,7 +29,6 @@ public NeuralNetwork(int i, int h, int o, boolean useMultiThreading) {
bias_o = new Matrix(o, 1);

this.useMultiThreading = useMultiThreading;

}

public NeuralNetwork(int i, int h, int o, double l_rate) {
Expand All @@ -36,7 +39,6 @@ public NeuralNetwork(int i, int h, int o, double l_rate) {
bias_o = new Matrix(o, 1);

this.l_rate = l_rate;

}

public NeuralNetwork(int i, int h, int o, double l_rate, boolean useMultiThreading) {
Expand All @@ -48,7 +50,6 @@ public NeuralNetwork(int i, int h, int o, double l_rate, boolean useMultiThreadi

this.l_rate = l_rate;
this.useMultiThreading = useMultiThreading;

}

public List<Double> predict(double[] X) {
Expand All @@ -75,35 +76,34 @@ public void fit(double[][] X, double[][] Y, int epochs, int verbose) {
switch (verbose) {

case 0: {
System.out.println("Staring training with " + epochs + " epochs");
U.println("Staring training with " + epochs + " epochs");
long start = System.currentTimeMillis();
for (int i = 0; i < epochs; i++) {
int sampleN = (int) (Math.random() * X.length);
this.train(X[sampleN], Y[sampleN], i + 1 == epochs);
}
long end = System.currentTimeMillis();
long elapsedTime = end - start;
System.out.println("Training took : " + (elapsedTime / 1000) + "s");
U.println("Training took : " + (elapsedTime / 1000) + "s\n");

break;
}

case 1: {
System.out.println("Staring training with " + epochs + " epochs");
U.println("Staring training with " + epochs + " epochs");
long start = System.currentTimeMillis();
for (int i = 0; i < epochs; i++) {
System.out.println("Epoch: " + (i + 1));
U.println("Epoch: " + (i + 1));
int sampleN = (int) (Math.random() * X.length);
this.train(X[sampleN], Y[sampleN], true);
}
long end = System.currentTimeMillis();
long elapsedTime = end - start;
System.out.println("Training took : " + (elapsedTime / 1000) + "s");
U.println("Training took : " + (elapsedTime / 1000) + "s");

break;
}
}

}

public void train(double[] X, double[] Y, Boolean showLoss) {
Expand Down Expand Up @@ -144,7 +144,6 @@ public void train(double[] X, double[] Y, Boolean showLoss) {

weights_ih.add(wih_delta);
bias_h.add(h_gradient);

}

private void printLoss(Matrix error) {
Expand All @@ -156,12 +155,54 @@ private void printLoss(Matrix error) {
}
}

System.out.print("Average Error: " + avg + "\n");
U.print("Average Error: " + avg + "\n");
}

public void testData(double[][] testing_in, double[][] testing_out) {
int i = 0;
List<Double> output;
for (double d[] : testing_in) {
output = predict(d);
U.printf("Q: %s\nIn: %s\nOut: %s\nPrediction: %s\n",i,U.arrStr(d),U.arrStr(testing_out[i++]),output.toString());
}
}

public double[][] getFileData(String fileName) {
List<String> rawData = U.readFile(fileName);
List<ArrayList> filteredData = new ArrayList<ArrayList>();

for (String n : rawData) {
String[] s = n.split(",");
List<Double> temp = new ArrayList<Double>();
for (String i : s) {
temp.add(Double.parseDouble(i));
}
filteredData.add((ArrayList<Double>) temp);
}

double[][] dataArr = nestedListToNestedArr(filteredData);

return dataArr;
}

private static double[][] nestedListToNestedArr(List<ArrayList> dataList) {
double[][] dataArr = new double[dataList.size()][dataList.get(0).size()];
int y = 0;
for (List<ArrayList> n : dataList) {
int x = 0;
for (Object i : n) {
dataArr[y][x] = ((Double) i);
x++;
}
y++;
}
return dataArr;
}

}

class Matrix {
private static BasicUtils U = new BasicUtils();

double[][] data;
int rows, cols;

Expand All @@ -179,9 +220,9 @@ public Matrix(int rows, int cols) {
public void print() {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
System.out.print(this.data[i][j] + " ");
U.print(this.data[i][j] + " ");
}
System.out.println();
U.println("");
}
}

Expand All @@ -196,7 +237,7 @@ public void add(int scaler) {

public void add(Matrix m) {
if (cols != m.cols || rows != m.rows) {
System.out.println("Shape Mismatch");
U.println("Shape Mismatch");
return;
}

Expand All @@ -212,7 +253,6 @@ public static Matrix fromArray(double[] x) {
for (int i = 0; i < x.length; i++)
temp.data[i][0] = x[i];
return temp;

}

public List<Double> toArray() {
Expand Down Expand Up @@ -285,15 +325,13 @@ public void multiply(double a) {
this.data[i][j] *= a;
}
}

}

public void sigmoid() {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++)
this.data[i][j] = 1 / (1 + Math.exp(-this.data[i][j]));
}

}

public Matrix dsigmoid() {
Expand All @@ -303,7 +341,6 @@ public Matrix dsigmoid() {
temp.data[i][j] = this.data[i][j] * (1 - this.data[i][j]);
}
return temp;

}
}

Expand Down Expand Up @@ -357,7 +394,6 @@ public void run() {
}

class RowMultiplyWorker implements Runnable {

private final Matrix result;
private Matrix matrix1;
private Matrix matrix2;
Expand All @@ -377,11 +413,57 @@ public void run() {
result.data[row][i] = 0;
for (int j = 0; j < matrix1.data[row].length; j++) {
result.data[row][i] += matrix1.data[row][j] * matrix2.data[j][i];

}

}

}
}

}
final class BasicUtils {
private static PrintWriter writer = new PrintWriter(System.out);

public static List<String> readFile(String fileName) {
List<String> out = new ArrayList<String>();
try {
File file = new File(fileName);
Scanner reader = new Scanner(file);
while (reader.hasNextLine()) {
String data = reader.nextLine();
out.add(data);
}
reader.close();
} catch (FileNotFoundException e) {
println("An error occurred.");
e.printStackTrace();
}

return out;
}

public static String arrStr(double[] arr) {
return Arrays.toString(arr);
}

public static void print(Object str) {
writer.write(String.valueOf(str));
writer.flush();
}

public static void println(Object str) {
print(str+"\n");
}

public static void printf(Object mainText, Object... x) {
String toPrint = "";

String[] y = String.valueOf(mainText).split("%s");
int i = 0;
for (String n : y) {
toPrint += n;
if (i<x.length) {
toPrint += x[i++];
}
}

println(toPrint);
}
}
4 changes: 4 additions & 0 deletions Neural_Network/testing_data_in.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0,0
0,1
1,0
1,1
4 changes: 4 additions & 0 deletions Neural_Network/testing_data_out.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0
1
1
0
4 changes: 4 additions & 0 deletions Neural_Network/training_data_in.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0,0
1,0
0,1
1,1
4 changes: 4 additions & 0 deletions Neural_Network/training_data_out.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0
1
1
0