Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 138 additions & 35 deletions src/test/java/com/library/numj/NumJTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@

import static org.junit.jupiter.api.Assertions.*;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

Expand Down Expand Up @@ -76,9 +74,9 @@ <T> void testArrayCreationWithSingleValue(T value) throws Exception {
/**
* Tests array creation with a given shape and data type.
*
* @param value The initial value of the array.
* @param value The initial value of the array.
* @param expectedShape The expected shape of the resulting array.
* @param dType The data type of the array elements.
* @param dType The data type of the array elements.
*/
@ParameterizedTest
@MethodSource("provideArrayCreationData")
Expand All @@ -98,44 +96,56 @@ <T> void testArrayCreationWithShape(T value, int expectedShape, DType dType) thr
* Tests zeros array creation with different shapes and data types.
*
* @param zeroValue The zero value for the array.
* @param shape The shape of the resulting array.
* @param dType The data type of the array elements.
* @param shape The shape of the resulting array.
* @param dType The data type of the array elements.
*/
@ParameterizedTest
@MethodSource("provideDataForZerosAndOnes")
@MethodSource("provideDataForValidZeros")
<T> void testZerosCreation(T zeroValue, int[] shape, DType dType) throws Exception {
NumJ numJ = new NumJ();
NDArray<T> result = numJ.zeros(shape, dType);
assertNotNull(result);
assertEquals(2, result.ndim());
assertArrayEquals((Object[]) Collections.nCopies(shape[0] * shape[1], zeroValue).toArray(), (Object[]) result.flatten().getArray());
T[][] expectedArray = (T[][]) new Object[shape[0]][shape[1]];
for (int i = 0; i < shape[0]; i++) {
Arrays.fill(expectedArray[i], zeroValue);
}
System.out.println("Expected-------->" + Arrays.deepToString(expectedArray));
result.printArray();
assertArrayEquals(expectedArray, (T[][]) result.getArray());
//assertArrayEquals((Object[]) Collections.nCopies(shape[0] * shape[1], zeroValue).toArray(), (Object[]) result.flatten().getArray());
}


/**
* Tests ones array creation with different shapes and data types.
*
* @param oneValue The one value for the array.
* @param shape The shape of the resulting array.
* @param dType The data type of the array elements.
* @param shape The shape of the resulting array.
* @param dType The data type of the array elements.
*/

@ParameterizedTest
@MethodSource("provideDataForZerosAndOnes")
@MethodSource("provideDataForOnes")
<T> void testOnesCreation(T oneValue, int[] shape, DType dType) throws Exception {
NumJ numJ = new NumJ();
NDArray<T> result = numJ.ones(shape, dType);
assertNotNull(result);
assertEquals(2, result.ndim());
assertArrayEquals((Object[]) java.util.Collections.nCopies(shape[0] * shape[1], oneValue).toArray(), (Object[]) result.flatten().getArray());
T[][] expectedArray = (T[][]) new Object[shape[0]][shape[1]];
for (int i = 0; i < shape[0]; i++) {
Arrays.fill(expectedArray[i], oneValue);
}
System.out.println("Expected-------->" + Arrays.deepToString(expectedArray));
result.printArray();
assertArrayEquals(expectedArray, (T[][]) result.getArray());
}


/**
* Tests arange function with different parameters.
*
* @param start The starting value of the sequence.
* @param end The ending value of the sequence.
* @param start The starting value of the sequence.
* @param end The ending value of the sequence.
* @param expectedSize The expected size of the resulting array.
*/
@ParameterizedTest
Expand All @@ -158,7 +168,7 @@ void testArange(int start, int end, int expectedSize) throws Exception {
* Generates expected data for arange tests.
*
* @param start The starting value of the sequence.
* @param end The ending value of the sequence.
* @param end The ending value of the sequence.
* @return A list of integers representing the expected arange result.
*/
private List<Integer> generateExpectedArangeData(int start, int end) {
Expand All @@ -183,17 +193,11 @@ void testArangeWithShapeAndSkip() throws Exception {
assertArrayEquals(expectedData, result.getArray());
}


/*private void assertArrayCreation(NDArray<Integer> result, Integer[] expected) {
assertEquals(expected.length, result.shape()[0]);
assertArrayEquals(expected, result.data());
}*/

/**
* Tests addition operation between two arrays.
*
* @param data1 The first array for addition.
* @param data2 The second array for addition.
* @param data1 The first array for addition.
* @param data2 The second array for addition.
* @param expected The expected result of the addition.
*/
@ParameterizedTest
Expand All @@ -210,8 +214,8 @@ <T> void testAddition(T data1, T data2, T expected) throws Exception {
/**
* Tests addition operation between two arrays.
*
* @param data1 The first array for Subtraction.
* @param data2 The second array for Subtraction.
* @param data1 The first array for Subtraction.
* @param data2 The second array for Subtraction.
* @param expected The expected result of the Subtraction.
*/
@ParameterizedTest
Expand All @@ -224,11 +228,12 @@ <T> void testSubtraction(T data1, T data2, T expected) throws Exception {
NDArray<T> result = numJ.subtract(arr1, arr2);
assertArrayEquals((Object[]) expected, (Object[]) result.getArray());
}

/**
* Tests addition operation between two arrays.
*
* @param data1 The first array for Multiplication.
* @param data2 The second array for Multiplication.
* @param data1 The first array for Multiplication.
* @param data2 The second array for Multiplication.
* @param expected The expected result of the Multiplication.
*/

Expand All @@ -248,8 +253,8 @@ <T> void testMultiplication(T data1, T data2, T expected) throws Exception {
/**
* Tests addition operation between two arrays.
*
* @param data1 The first array for division.
* @param data2 The second array for division.
* @param data1 The first array for division.
* @param data2 The second array for division.
* @param expected The expected result of the division.
*/
@ParameterizedTest
Expand All @@ -266,12 +271,12 @@ <T> void testDivision(T data1, T data2, T expected) throws Exception {
/**
* Tests the transpose operation on an array.
*
* @param data The input array to be transposed.
* @param data The input array to be transposed.
* @param expected The expected result after transposing.
*/
@ParameterizedTest
@MethodSource("provideDataForTranspose")
<T> void testTranspose(T data , T expected) throws Exception {
<T> void testTranspose(T data, T expected) throws Exception {
NumJ numJ = new NumJ();
NDArray<T> arr = numJ.array(data);
NDArray<T> transposed = numJ.transpose(arr);
Expand Down Expand Up @@ -377,16 +382,73 @@ public void testArangeWithZeroSkip() throws ShapeException {
}


/**
* Tests creation of identity matrix using eye method.
*
* @param size The size of the identity matrix.
* @param expected The expected result of the identity matrix.
*/
@ParameterizedTest
@MethodSource("provideValidEyeData")
<T> void testEyeCreation(int size, T expected) throws Exception {
NDArray<T> result = numJ.eye(size);
result.printArray();
assertNotNull(result);
assertEquals(size, result.shape().get(0));
assertEquals(size, result.shape().get(1));
//assertArrayEquals(expected, (T[]) result.getArray());
}

/**
* Tests that a ShapeException is thrown when creating an eye matrix with invalid dimensions.
*
* @param rows The number of rows.
* @param cols The number of columns.
*/
@ParameterizedTest
@MethodSource("provideInvalidEyeDimensions")
void testEyeWithInvalidDimensions(int rows, int cols) {
assertThrows(ShapeException.class, () -> {
numJ.eye(rows, cols);
}, "Expected ShapeException for invalid dimensions");
}

/**
* Tests the eye function with invalid dimensions, expecting an exception.
*
* @param rows The number of rows in the identity matrix.
* @param cols The number of columns in the identity matrix.
*/
@ParameterizedTest
@MethodSource("provideInvalidEyeDimensions")
void testEyeInvalidDimensions(int rows, int cols) {
assertThrows(IllegalArgumentException.class, () -> {
numJ.eye(rows, cols);
});
}


/**
* Provides data for zeros array creation tests.
*
* @return A stream of arguments for testing zeros creation.
*/
static Stream<Arguments> provideDataForZerosAndOnes() {
static Stream<Arguments> provideDataForValidZeros() {
return Stream.of(
Arguments.of(0, new int[]{2, 2}, DType.INT32),
Arguments.of(1.0, new int[]{2, 2}, DType.FLOAT64)
Arguments.of(0.0, new int[]{2, 2}, DType.FLOAT64),
Arguments.of("0", new int[]{2, 2}, DType.OBJECT)
);
}


static Stream<Arguments> provideDataForOnes() {
return Stream.of(
Arguments.of(1, new int[]{2, 2}, DType.INT32),
Arguments.of(1L, new int[]{2, 2}, DType.INT64),
Arguments.of(1.0, new int[]{2, 2}, DType.FLOAT64),
Arguments.of("1", new int[]{2, 2}, DType.OBJECT)

//Arguments.of("0", new int[]{2, 2}, DType.STR)
);
}
Expand Down Expand Up @@ -452,7 +514,7 @@ static Stream<Arguments> provideArrayCreationData() {
return Stream.of(
Arguments.of(1, 2, DType.INT32),
Arguments.of(2.0, 3, DType.FLOAT64)
// Arguments.of("3", 4, DType.STR)
// Arguments.of("3", 4, DType.STR)
);
}

Expand All @@ -464,4 +526,45 @@ static Stream<Arguments> shapeProvider() {
);
}

/**
* Provides test data for eye function.
*
* @return A Stream of arguments for testing the eye function.
*/
private static Stream<Arguments> provideDataForEye() throws ShapeException {
return Stream.of(
Arguments.of(3, 3, new NDArray<>(new Double[][]{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}})),
Arguments.of(2, 4, new NDArray<>(new Double[][]{{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}}))
);
}

/**
* Provides valid test data for the eye function.
*
* @return A Stream of arguments for testing the eye function with valid data.
*/
private static Stream<Arguments> provideValidEyeData() throws ShapeException {
return Stream.of(
Arguments.of(3, 3, new NDArray<>(new Double[][]{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}})),
Arguments.of(2, 4, new NDArray<>(new Double[][]{{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}})),
Arguments.of(1, 1, new NDArray<>(new Double[][]{{1.0}})), // Test case for a 1x1 identity matrix
Arguments.of(5, 5, new NDArray<>(new Double[][]{{1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0}})) // Test case for a 5x5 identity matrix
);
}

/**
* Provides invalid test data for the eye function.
*
* @return A Stream of arguments for testing the eye function with invalid dimensions.
*/
private static Stream<Arguments> provideInvalidEyeDimensions() {
return Stream.of(
Arguments.of(-1, 3), // Negative rows
Arguments.of(3, -1), // Negative columns
Arguments.of(0, 3), // Zero rows
Arguments.of(3, 0) // Zero columns
);
}


}