diff --git a/AGENTS.md b/AGENTS.md index 70fecf39..1b734e47 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,6 +48,7 @@ When adding or changing user-facing behavior: - happy path - 1–2 edge cases - regression test for the bug/feature request (if applicable) + - avoid using finnts::: or finnts:: to call internal/exported functions; instead, call them directly (e.g. use `my_function()` instead of `finnts:::my_function()`). 4. **Docs**: - update roxygen comments (`@param`, `@return`, `@examples`) - run `devtools::document()` so `man/*.Rd` stays in sync diff --git a/DESCRIPTION b/DESCRIPTION index 58209421..2e34e987 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: finnts Title: Microsoft Finance Time Series Forecasting Framework -Version: 0.6.0.9022 +Version: 0.6.0.9023 Authors@R: c(person(given = "Mike", family = "Tokic", @@ -43,6 +43,7 @@ Imports: glue, glmnet, gtools, + hardhat, hts, kernlab, lubridate, diff --git a/NEWS.md b/NEWS.md index 8aa222f2..5db20031 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,4 @@ -# finnts 0.6.0.9022 (development version) +# finnts 0.6.0.9023 (development version) ## Improvements - TimeGPT Integration diff --git a/R/ensemble_models.R b/R/ensemble_models.R index 56c7777e..2ebb79db 100644 --- a/R/ensemble_models.R +++ b/R/ensemble_models.R @@ -293,7 +293,7 @@ ensemble_models <- function(run_info, set.seed(seed) - grid <- dials::grid_latin_hypercube(parameters, size = num_hyperparameters) + grid <- dials::grid_space_filling(parameters, size = num_hyperparameters) hyperparameters_temp <- grid %>% dplyr::group_split(dplyr::row_number(), .keep = FALSE) %>% diff --git a/R/models.R b/R/models.R index 9a5afeec..45d1042c 100644 --- a/R/models.R +++ b/R/models.R @@ -481,15 +481,15 @@ get_kfold_tune_grid <- function(train_data, ) } -#' Get grid_latin_hypercube +#' Get space-filling grid #' #' @param model_spec Model Spec Obj #' -#' @return gives the latin hypercube grid +#' @return gives a space-filling parameter grid #' @noRd -get_latin_hypercube_grid <- function(model_spec) { - dials::grid_latin_hypercube( - dials::parameters(model_spec), +get_space_filling_grid <- function(model_spec) { + dials::grid_space_filling( + hardhat::extract_parameter_set_dials(model_spec), size = 10 ) } diff --git a/R/parallel_util.R b/R/parallel_util.R index ff5e47d0..f47e3694 100644 --- a/R/parallel_util.R +++ b/R/parallel_util.R @@ -54,8 +54,8 @@ par_start <- function(run_info, add_packages <- c(add_packages, "arrow") } - if (run_info$object_output == "qs") { - add_packages <- c(add_packages, "qs") + if (run_info$object_output == "qs2") { + add_packages <- c(add_packages, "qs2") } # register cluster diff --git a/R/timegpt_model.R b/R/timegpt_model.R index c1b01204..c9b85111 100644 --- a/R/timegpt_model.R +++ b/R/timegpt_model.R @@ -320,16 +320,22 @@ pad_time_series_data <- function(train_df, date_type, min_size = NULL) { } # For each combo, calculate how far back we need to go to add required rows - combos_to_pad <- combos_to_pad %>% - dplyr::mutate( - start_date = dplyr::case_when( - date_type == "day" ~ earliest_date - lubridate::days(rows_to_add), - date_type == "week" ~ earliest_date - lubridate::weeks(rows_to_add), - date_type == "month" ~ earliest_date - months(rows_to_add), - date_type == "quarter" ~ earliest_date - months(rows_to_add * 3), - date_type == "year" ~ earliest_date - lubridate::years(rows_to_add) - ) - ) + if (date_type == "day") { + combos_to_pad <- combos_to_pad %>% + dplyr::mutate(start_date = earliest_date - lubridate::days(rows_to_add)) + } else if (date_type == "week") { + combos_to_pad <- combos_to_pad %>% + dplyr::mutate(start_date = earliest_date - lubridate::weeks(rows_to_add)) + } else if (date_type == "month") { + combos_to_pad <- combos_to_pad %>% + dplyr::mutate(start_date = earliest_date - months(rows_to_add)) + } else if (date_type == "quarter") { + combos_to_pad <- combos_to_pad %>% + dplyr::mutate(start_date = earliest_date - months(rows_to_add * 3)) + } else if (date_type == "year") { + combos_to_pad <- combos_to_pad %>% + dplyr::mutate(start_date = earliest_date - lubridate::years(rows_to_add)) + } # Create complete date sequences for each combo create_date_sequence <- function(start, end, by_type) { diff --git a/R/utility.R b/R/utility.R index 0dc27b5d..d077a5c8 100644 --- a/R/utility.R +++ b/R/utility.R @@ -28,7 +28,7 @@ utils::globalVariables(c( "to", "total_rows", "weighted_mape", "Analysis_Type", "Metric", "Value_Numeric", "is_stationary", "outlier_pct", "model_class", "section", "value", "Hierarchy_Level", "Sort_Order", "run_id", "date_type", "file_path", "models_to_run", "underscore_count", - "max_iterations", "run_complete" + "max_iterations", "run_complete", "earliest_date", "rows_to_add" )) #' @importFrom magrittr %>% diff --git a/tests/testthat/test-agent_helpers.R b/tests/testthat/test-agent_helpers.R new file mode 100644 index 00000000..95ca7381 --- /dev/null +++ b/tests/testthat/test-agent_helpers.R @@ -0,0 +1,352 @@ +# tests/testthat/test-agent_helpers.R +# Tests for helper functions in agent_iterate_forecast.R that don't require LLM + +test_that("collapse_or_na returns NA for 'NULL' string", { + result <- collapse_or_na("NULL") + expect_true(is.na(result)) + expect_type(result, "character") +}) + +test_that("collapse_or_na collapses multiple elements with ---", { + result <- collapse_or_na(c("a", "b", "c")) + expect_equal(result, "a---b---c") +}) + +test_that("collapse_or_na returns single element unchanged", { + + result <- collapse_or_na("hello") + expect_equal(result, "hello") +}) + +test_that("collapse_or_na handles empty vector", { + result <- collapse_or_na(character(0)) + expect_equal(result, "") +}) + +test_that("apply_column_types coerces numeric to integer", { + target <- tibble::tibble(x = c(1.0, 2.0, 3.0)) + template <- tibble::tibble(x = 1L) + result <- apply_column_types(target, template) + expect_type(result$x, "integer") +}) + +test_that("apply_column_types coerces character to numeric", { + target <- tibble::tibble(x = c("1", "2", "3")) + template <- tibble::tibble(x = 1.0) + result <- apply_column_types(target, template) + expect_type(result$x, "double") +}) + +test_that("apply_column_types coerces to Date", { + target <- tibble::tibble(x = c("2020-01-01", "2020-02-01")) + template <- tibble::tibble(x = as.Date("2020-01-01")) + result <- apply_column_types(target, template) + expect_s3_class(result$x, "Date") +}) + +test_that("apply_column_types coerces to factor", { + target <- tibble::tibble(x = c("a", "b", "c")) + template <- tibble::tibble(x = factor("a", levels = c("a", "b", "c", "d"))) + result <- apply_column_types(target, template) + expect_s3_class(result$x, "factor") + expect_equal(levels(result$x), c("a", "b", "c", "d")) +}) + +test_that("apply_column_types coerces to logical", { + target <- tibble::tibble(x = c(1, 0, 1)) + template <- tibble::tibble(x = TRUE) + result <- apply_column_types(target, template) + expect_type(result$x, "logical") +}) + +test_that("apply_column_types coerces to character", { + target <- tibble::tibble(x = c(1, 2, 3)) + template <- tibble::tibble(x = "hello") + result <- apply_column_types(target, template) + expect_type(result$x, "character") +}) + +test_that("apply_column_types preserves columns not in template", { + target <- tibble::tibble(x = 1, y = "extra") + template <- tibble::tibble(x = 1L) + result <- apply_column_types(target, template, drop_extra = FALSE) + expect_true("y" %in% colnames(result)) +}) + +test_that("apply_column_types drops extra columns when requested", { + target <- tibble::tibble(x = 1, y = "extra") + template <- tibble::tibble(x = 1L) + result <- apply_column_types(target, template, drop_extra = TRUE) + expect_false("y" %in% colnames(result)) + expect_equal(ncol(result), 1) +}) + +test_that("apply_column_types reorders columns when requested", { + target <- tibble::tibble(z = "z", x = 1, y = 2) + template <- tibble::tibble(x = 1L, y = 2L) + result <- apply_column_types(target, template, reorder = TRUE) + expect_equal(colnames(result)[1], "x") + expect_equal(colnames(result)[2], "y") +}) + +test_that("apply_column_types coerces to POSIXct", { + target <- tibble::tibble(x = c("2020-01-01 12:00:00", "2020-02-01 13:00:00")) + template <- tibble::tibble(x = as.POSIXct("2020-01-01", tz = "UTC")) + result <- apply_column_types(target, template) + expect_s3_class(result$x, "POSIXct") +}) + +test_that("does_param_set_exist finds matching row", { + df <- tibble::tibble( + a = c(1, 2, 3), + b = c("x", "y", "z") + ) + + result <- does_param_set_exist(list(a = 2, b = "y"), df) + expect_true(result) +}) + +test_that("does_param_set_exist returns FALSE for no match", { + df <- tibble::tibble( + a = c(1, 2, 3), + b = c("x", "y", "z") + ) + + result <- does_param_set_exist(list(a = 4, b = "w"), df) + expect_false(result) +}) + +test_that("does_param_set_exist works with subset of columns", { + df <- tibble::tibble( + a = c(1, 2, 3), + b = c("x", "y", "z"), + c = c(10, 20, 30) + ) + + result <- does_param_set_exist(list(a = 1, b = "x"), df) + expect_true(result) +}) + +test_that("extract_json_object parses simple JSON", { + skip_if_not_installed("jsonlite") + raw <- '{"key": "value", "num": 42}' + result <- extract_json_object(raw) + expect_equal(result$key, "value") + expect_equal(result$num, 42) +}) + +test_that("extract_json_object strips code fences", { + skip_if_not_installed("jsonlite") + raw <- '```json\n{"key": "value"}\n```' + result <- extract_json_object(raw) + expect_equal(result$key, "value") +}) + +test_that("extract_json_object handles surrounding text", { + skip_if_not_installed("jsonlite") + raw <- 'Here is the JSON: {"answer": "yes"} and some more text.' + result <- extract_json_object(raw) + expect_equal(result$answer, "yes") +}) + +test_that("extract_json_object errors on non-character input", { + expect_error( + extract_json_object(123), + "must be a single character string" + ) +}) + +test_that("extract_json_object errors on missing JSON", { + expect_error( + extract_json_object("no json here"), + "No JSON object found" + ) +}) + +test_that("extract_json_object errors on invalid JSON", { + expect_error( + extract_json_object('{"key": }'), + "Failed to parse JSON" + ) +}) + +test_that("null_converter returns NULL for 'NULL' string", { + result <- null_converter("NULL") + expect_null(result) +}) + +test_that("null_converter returns input for non-NULL string", { + result <- null_converter("hello") + expect_equal(result, "hello") +}) + +test_that("null_converter returns vector for multi-element input", { + result <- null_converter(c("a", "b")) + expect_equal(result, c("a", "b")) +}) + +# -- clean_markdown tests -- + +test_that("clean_markdown removes bold markers", { + result <- clean_markdown("This is **bold** text") + expect_equal(result, "This is bold text") +}) + +test_that("clean_markdown removes underscore bold markers", { + result <- clean_markdown("This is __bold__ text") + expect_equal(result, "This is bold text") +}) + +test_that("clean_markdown removes italic markers", { + result <- clean_markdown("This is *italic* text") + expect_equal(result, "This is italic text") +}) + +test_that("clean_markdown removes header markers", { + result <- clean_markdown("# Header text") + expect_equal(result, "Header text") +}) + +test_that("clean_markdown removes multi-level headers", { + result <- clean_markdown("### Third level header") + expect_equal(result, "Third level header") +}) + +test_that("clean_markdown removes code blocks", { + result <- clean_markdown("Use `code` in text") + expect_equal(result, "Use code in text") +}) + +test_that("clean_markdown cleans excess whitespace", { + result <- clean_markdown("Text\n\n\n\nMore text") + expect_equal(result, "Text\n\nMore text") +}) + +test_that("clean_markdown returns empty string for empty input", { + result <- clean_markdown("") + expect_equal(result, "") +}) + +test_that("clean_markdown with keep_lists preserves bullet markers", { + text <- "- Item 1\n- Item 2\n* Item 3" + result <- clean_markdown(text, keep_lists = TRUE) + expect_true(grepl("- Item 1", result)) + expect_true(grepl("- Item 2", result)) + expect_true(grepl("\\* Item 3", result)) +}) + +test_that("clean_markdown with keep_lists=FALSE still removes list-like italics", { + text <- "Some *italic* text" + result <- clean_markdown(text, keep_lists = FALSE) + expect_false(grepl("\\*", result)) +}) + +# -- summarize_analysis_results tests -- + +test_that("summarize_analysis_results handles data.frame result", { + results <- list( + step_1 = data.frame(x = 1:3, y = c("a", "b", "c")) + ) + output <- summarize_analysis_results(results, max_rows = 10) + expect_type(output, "character") + expect_true(grepl("3 rows total", output)) +}) + +test_that("summarize_analysis_results handles numeric result", { + results <- list(step_1 = 42.5) + output <- summarize_analysis_results(results) + expect_true(grepl("42.5", output)) +}) + +test_that("summarize_analysis_results handles character result", { + results <- list(step_1 = "some text result") + output <- summarize_analysis_results(results) + expect_true(grepl("some text result", output)) +}) + +test_that("summarize_analysis_results handles list result", { + results <- list(step_1 = list(a = 1, b = "text")) + output <- summarize_analysis_results(results) + expect_type(output, "character") + expect_true(nchar(output) > 0) +}) + +test_that("summarize_analysis_results truncates long data frames", { + results <- list( + step_1 = data.frame(x = 1:100, y = letters[rep(1:26, length.out = 100)]) + ) + output <- summarize_analysis_results(results, max_rows = 5) + expect_true(grepl("100 rows total, showing 5", output)) + expect_true(grepl("more rows omitted", output)) +}) + +test_that("summarize_analysis_results handles multiple steps", { + results <- list( + step_1 = data.frame(x = 1:3), + step_2 = 42, + step_3 = "text" + ) + output <- summarize_analysis_results(results) + expect_true(grepl("1 result", output)) + expect_true(grepl("2 result", output)) + expect_true(grepl("3 result", output)) +}) + +test_that("summarize_analysis_results handles empty results", { + results <- list() + output <- summarize_analysis_results(results) + expect_equal(output, "") +}) + +# -- display_answer tests -- + +test_that("display_answer prints to console", { + expect_output(display_answer("This is a test answer"), "This is a test answer") +}) + +test_that("display_answer cleans markdown in answer", { + expect_output( + display_answer("This is **bold** text"), + "This is bold text" + ) +}) + +# -- sanitize_args tests (from agent_run.R) -- + +test_that("sanitize_args keeps atomic values", { + result <- sanitize_args(list(a = "hello", b = 42, c = TRUE)) + expect_equal(result$a, "hello") + expect_equal(result$b, 42) + expect_equal(result$c, TRUE) +}) + +test_that("sanitize_args replaces non-atomic values with placeholder", { + result <- sanitize_args(list( + simple = "text", + complex = data.frame(x = 1:3) + )) + expect_equal(result$simple, "text") + expect_equal(result$complex, "") +}) + +test_that("sanitize_args handles empty list", { + result <- sanitize_args(list()) + expect_equal(length(result), 0) +}) + +test_that("sanitize_args handles list values", { + result <- sanitize_args(list( + x = 1, + y = list(a = 1, b = 2) + )) + expect_equal(result$x, 1) + expect_type(result$y, "character") + expect_true(grepl("object", result$y)) +}) + +test_that("sanitize_args handles NULL values in list", { + result <- sanitize_args(list(a = NULL, b = "text")) + # NULL is non-atomic, so sanitize_args converts it to a description string + expect_type(result$a, "character") + expect_equal(result$b, "text") +}) diff --git a/tests/testthat/test-agent_summarize_helpers.R b/tests/testthat/test-agent_summarize_helpers.R new file mode 100644 index 00000000..b46b5f01 --- /dev/null +++ b/tests/testthat/test-agent_summarize_helpers.R @@ -0,0 +1,288 @@ +# Tests for helper functions in agent_summarize_models.R + +test_that(".chr1 handles NULL", { + expect_equal(.chr1(NULL), "auto") +}) + +test_that(".chr1 handles missing quosure", { + q <- rlang::quo() + expect_equal(.chr1(q), "auto") +}) + +test_that(".chr1 handles non-missing quosure", { + q <- rlang::quo(my_var) + expect_equal(.chr1(q), "my_var") +}) + +test_that(".chr1 handles language object", { + expr <- quote(x + y) + expect_equal(.chr1(expr), "x + y") +}) + +test_that(".chr1 handles zero-length vector", { + expect_equal(.chr1(character(0)), "auto") +}) + +test_that(".chr1 handles single atomic value", { + expect_equal(.chr1(42), "42") + expect_equal(.chr1("hello"), "hello") + expect_equal(.chr1(TRUE), "TRUE") +}) + +test_that(".chr1 handles multi-element atomic vector", { + expect_equal(.chr1(c(1, 2, 3)), "1,2,3") + expect_equal(.chr1(c("a", "b")), "a,b") +}) + +test_that(".chr1 handles rlang missing", { + expect_equal(.chr1(rlang::missing_arg()), "auto") +}) + +test_that(".kv creates proper tibble", { + result <- .kv("sec", "nm", "val") + expect_s3_class(result, "tbl_df") + expect_equal(nrow(result), 1) + expect_equal(result$section, "sec") + expect_equal(result$name, "nm") + expect_equal(result$value, "val") +}) + +test_that(".kv coerces types to character", { + result <- .kv(1, 2, 3) + expect_equal(result$section, "1") + expect_equal(result$name, "2") + expect_equal(result$value, "3") +}) + +test_that(".unquote strips double quotes", { + expect_equal(.unquote('"hello"'), "hello") +}) + +test_that(".unquote strips single quotes", { + expect_equal(.unquote("'hello'"), "hello") +}) + +test_that(".unquote leaves unquoted strings alone", { + expect_equal(.unquote("hello"), "hello") +}) + +test_that(".unquote handles non-character input", { + expect_equal(.unquote(42), "42") +}) + +test_that(".signif_chr formats numeric values", { + expect_equal(.signif_chr("3.14159", digits = 3), "3.14") + expect_equal(.signif_chr("1234567", digits = 4), "1235000") +}) + +test_that(".signif_chr passes through non-numeric strings", { + expect_equal(.signif_chr("hello"), "hello") + expect_equal(.signif_chr("NA"), "NA") +}) + +test_that(".signif_chr handles Inf and NaN", { + expect_equal(.signif_chr("Inf"), "Inf") + expect_equal(.signif_chr("NaN"), "NaN") +}) + +test_that(".extract_predictors handles try-error", { + result <- .extract_predictors(structure("error", class = "try-error")) + expect_s3_class(result, "tbl_df") + expect_equal(nrow(result), 0) + expect_equal(names(result), c("section", "name", "value")) +}) + +test_that(".extract_predictors handles NULL predictors", { + mold <- list(predictors = NULL) + result <- .extract_predictors(mold) + expect_equal(nrow(result), 0) +}) + +test_that(".extract_predictors handles empty predictors", { + mold <- list(predictors = data.frame()) + result <- .extract_predictors(mold) + expect_equal(nrow(result), 0) +}) + +test_that(".extract_predictors extracts predictor info", { + mold <- list(predictors = data.frame(x1 = 1:3, x2 = c("a", "b", "c"), stringsAsFactors = FALSE)) + result <- .extract_predictors(mold) + expect_equal(nrow(result), 2) + expect_true(all(result$section == "predictor")) + expect_true("x1" %in% result$name) + expect_true("x2" %in% result$name) +}) + +test_that(".extract_outcomes handles try-error", { + result <- .extract_outcomes(structure("error", class = "try-error")) + expect_equal(nrow(result), 0) +}) + +test_that(".extract_outcomes handles NULL outcomes", { + mold <- list(outcomes = NULL) + result <- .extract_outcomes(mold) + expect_equal(nrow(result), 0) +}) + +test_that(".extract_outcomes extracts outcome info", { + mold <- list(outcomes = data.frame(Target = c(1.5, 2.5, 3.5))) + result <- .extract_outcomes(mold) + expect_equal(nrow(result), 1) + expect_equal(result$section, "outcome") + expect_equal(result$name, "Target") +}) + +test_that(".extract_recipe_steps handles non-recipe object", { + result <- .extract_recipe_steps("not_a_recipe") + expect_equal(nrow(result), 0) + expect_equal(names(result), c("section", "name", "value")) +}) + +test_that(".extract_recipe_steps handles recipe with steps", { + rec <- recipes::recipe(Target ~ ., data = data.frame(Target = 1:5, x1 = 1:5)) %>% + recipes::step_zv(recipes::all_predictors()) %>% + recipes::step_normalize(recipes::all_numeric_predictors()) + result <- .extract_recipe_steps(rec) + expect_equal(nrow(result), 2) + expect_true(all(result$section == "recipe_step")) + expect_true(any(grepl("zero variance", result$value, ignore.case = TRUE) | + grepl("step_zv", result$value, ignore.case = TRUE))) +}) + +test_that(".infer_period_from_dates handles try-error", { + result <- .infer_period_from_dates(structure("error", class = "try-error")) + expect_true(is.na(result)) +}) + +test_that(".infer_period_from_dates handles no date columns", { + mold <- list(predictors = data.frame(x = 1:10)) + result <- .infer_period_from_dates(mold) + expect_true(is.na(result)) +}) + +test_that(".infer_period_from_dates infers monthly frequency", { + dates <- seq(as.Date("2020-01-01"), by = "month", length.out = 24) + mold <- list(predictors = data.frame(Date = dates, x = 1:24)) + result <- .infer_period_from_dates(mold) + expect_equal(result, "12") +}) + +test_that(".infer_period_from_dates infers daily frequency", { + dates <- seq(as.Date("2020-01-01"), by = "day", length.out = 100) + mold <- list(predictors = data.frame(Date = dates, x = 1:100)) + result <- .infer_period_from_dates(mold) + expect_equal(result, "7") +}) + +test_that(".infer_period_from_dates infers yearly frequency", { + dates <- seq(as.Date("2000-01-01"), by = "year", length.out = 20) + mold <- list(predictors = data.frame(Date = dates, x = 1:20)) + result <- .infer_period_from_dates(mold) + expect_equal(result, "1") +}) + +test_that(".infer_period_from_dates infers weekly frequency", { + dates <- seq(as.Date("2020-01-01"), by = "week", length.out = 52) + mold <- list(predictors = data.frame(Date = dates, x = 1:52)) + result <- .infer_period_from_dates(mold) + expect_equal(result, "52") +}) + +test_that(".infer_period_from_dates infers quarterly frequency", { + dates <- seq(as.Date("2010-01-01"), by = "quarter", length.out = 20) + mold <- list(predictors = data.frame(Date = dates, x = 1:20)) + result <- .infer_period_from_dates(mold) + expect_equal(result, "4") +}) + +test_that(".find_obj returns NULL when depth exhausted", { + result <- .find_obj(list(a = 1), function(x) FALSE, depth = 0) + expect_null(result) +}) + +test_that(".find_obj returns NULL for NULL input", { + result <- .find_obj(NULL, function(x) TRUE) + expect_null(result) +}) + +test_that(".find_obj finds top-level match", { + result <- .find_obj(42, function(x) is.numeric(x) && x == 42) + expect_equal(result, 42) +}) + +test_that(".find_obj finds nested list element", { + obj <- list(a = list(b = list(c = "found_it"))) + result <- .find_obj(obj, function(x) identical(x, "found_it")) + expect_equal(result, "found_it") +}) + +test_that(".find_obj finds in environment", { + e <- new.env(parent = emptyenv()) + e$target <- "found" + result <- .find_obj(e, function(x) identical(x, "found")) + expect_equal(result, "found") +}) + +test_that(".find_obj returns NULL when not found", { + obj <- list(a = 1, b = 2, c = 3) + result <- .find_obj(obj, function(x) identical(x, "missing")) + expect_null(result) +}) + +test_that(".assemble_output builds proper tibble", { + preds <- tibble::tibble(section = "predictor", name = "x1", value = "numeric") + outs <- tibble::tibble(section = "outcome", name = "Target", value = "numeric") + steps <- tibble::tibble(section = "recipe_step", name = "1", value = "Normalize") + args <- tibble::tibble(section = "model_arg", name = "trees", value = "100") + eng <- tibble::tibble(section = "engine_param", name = "nthread", value = "1") + + result <- .assemble_output(preds, outs, steps, args, eng, + model_class = "boost_tree", engine = "xgboost" + ) + expect_s3_class(result, "tbl_df") + expect_true("model_class" %in% names(result)) + expect_true("engine" %in% names(result)) + expect_true(all(result$model_class == "boost_tree")) + expect_true(all(result$engine == "xgboost")) + # verify ordering: predictor < outcome < recipe_step < model_arg < engine_param + section_order <- result$section + pred_idx <- which(section_order == "predictor") + out_idx <- which(section_order == "outcome") + step_idx <- which(section_order == "recipe_step") + arg_idx <- which(section_order == "model_arg") + eng_idx <- which(section_order == "engine_param") + if (length(pred_idx) > 0 && length(out_idx) > 0) { + expect_true(max(pred_idx) < min(out_idx)) + } + if (length(out_idx) > 0 && length(step_idx) > 0) { + expect_true(max(out_idx) < min(step_idx)) + } +}) + +test_that(".assemble_output unquotes values when requested", { + args <- tibble::tibble(section = "model_arg", name = "penalty", value = '"0.01"') + result <- .assemble_output( + tibble::tibble(section = character(), name = character(), value = character()), + tibble::tibble(section = character(), name = character(), value = character()), + tibble::tibble(section = character(), name = character(), value = character()), + args, + tibble::tibble(section = character(), name = character(), value = character()), + model_class = "linear_reg", engine = "glmnet", + unquote_values = TRUE + ) + expect_equal(unname(result$value[result$name == "penalty"]), "0.01") +}) + +test_that(".assemble_output deduplicates rows", { + args1 <- tibble::tibble(section = "model_arg", name = "trees", value = "100") + args2 <- tibble::tibble(section = "model_arg", name = "trees", value = "100") + result <- .assemble_output( + tibble::tibble(section = character(), name = character(), value = character()), + tibble::tibble(section = character(), name = character(), value = character()), + tibble::tibble(section = character(), name = character(), value = character()), + dplyr::bind_rows(args1, args2), + tibble::tibble(section = character(), name = character(), value = character()), + model_class = "boost_tree", engine = "xgboost" + ) + expect_equal(sum(result$name == "trees"), 1) +}) diff --git a/tests/testthat/test-feature_selection.R b/tests/testthat/test-feature_selection.R new file mode 100644 index 00000000..00852c15 --- /dev/null +++ b/tests/testthat/test-feature_selection.R @@ -0,0 +1,184 @@ +# tests/testthat/test-feature_selection.R +# Tests for feature selection functions in feature_selection.R + +# -- target_corr_fn tests -- + +test_that("target_corr_fn returns correlated features", { + skip_if_not_installed("corrr") + set.seed(123) + n <- 50 + x1 <- rnorm(n) + data <- tibble::tibble( + Combo = rep("A", n), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = n), + Target = x1 * 2 + rnorm(n, sd = 0.1), + Feature_A = x1, + Feature_B = rnorm(n) + ) + + result <- target_corr_fn(data, threshold = 0.5) + + expect_true("term" %in% colnames(result)) + expect_true("Target" %in% colnames(result)) + expect_true("Feature_A" %in% result$term) +}) + +test_that("target_corr_fn respects threshold", { + skip_if_not_installed("corrr") + set.seed(42) + n <- 50 + x1 <- rnorm(n) + data <- tibble::tibble( + Combo = rep("A", n), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = n), + Target = x1 * 2 + rnorm(n, sd = 0.1), + Feature_A = x1 + rnorm(n, sd = 0.5), + Feature_B = rnorm(n) + ) + + result_high <- target_corr_fn(data, threshold = 0.9) + result_low <- target_corr_fn(data, threshold = 0.1) + + expect_true(nrow(result_low) >= nrow(result_high)) +}) + +test_that("target_corr_fn handles single feature", { + skip_if_not_installed("corrr") + set.seed(123) + n <- 30 + data <- tibble::tibble( + Combo = rep("A", n), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = n), + Target = rnorm(n), + Feature_A = rnorm(n) + ) + + result <- target_corr_fn(data, threshold = 0.5) + expect_s3_class(result, "tbl_df") +}) + +test_that("target_corr_fn handles no correlated features", { + skip_if_not_installed("corrr") + set.seed(999) + n <- 100 + data <- tibble::tibble( + Combo = rep("A", n), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = n), + Target = rnorm(n), + Feature_A = rnorm(n), + Feature_B = rnorm(n) + ) + + result <- target_corr_fn(data, threshold = 0.99) + expect_s3_class(result, "tbl_df") +}) + +# -- Variable importance function tests -- + +test_that("vip_rf_fn produces variable importance scores", { + skip_if_not_installed("ranger") + skip_if_not_installed("vip") + set.seed(42) + data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "day", length.out = 100), + Target = rnorm(100, mean = 50, sd = 10), + x1 = rnorm(100), + x2 = rnorm(100), + x3 = rnorm(100) + ) + data$Target <- data$Target + 2 * data$x1 + + result <- vip_rf_fn(data, seed = 42) + expect_s3_class(result, "tbl_df") + expect_true("Variable" %in% names(result)) + expect_true("Importance" %in% names(result)) + expect_true(all(result$Importance > 0)) +}) + +test_that("vip_lm_fn produces variable importance scores", { + set.seed(42) + data <- tibble::tibble( + Combo = rep("A", 100), + Date = seq(as.Date("2020-01-01"), by = "day", length.out = 100), + Target = rnorm(100, mean = 50, sd = 10), + x1 = rnorm(100), + x2 = rnorm(100), + x3 = rnorm(100) + ) + data$Target <- data$Target + 3 * data$x1 + + result <- vip_lm_fn(data, seed = 42) + expect_s3_class(result, "tbl_df") + expect_true("Variable" %in% names(result)) + expect_true("Importance" %in% names(result)) + expect_true(all(result$Importance > 0)) +}) + +test_that("vip_cubist_fn produces variable importance scores", { + set.seed(42) + data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "day", length.out = 100), + Target = rnorm(100, mean = 50, sd = 10), + x1 = rnorm(100), + x2 = rnorm(100) + ) + data$Target <- data$Target + 5 * data$x1 + + result <- vip_cubist_fn(data, seed = 42) + expect_s3_class(result, "tbl_df") + expect_true("Variable" %in% names(result)) + expect_true("Importance" %in% names(result)) +}) + +# -- boruta_fn tests -- + +test_that("boruta_fn selects important features", { + set.seed(42) + n <- 100 + x1 <- rnorm(n) + data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "day", length.out = n), + Combo = rep("A", n), + Target = x1 * 5 + rnorm(n, sd = 0.1), + Feature_Strong = x1, + Feature_Noise = rnorm(n) + ) + + result <- boruta_fn(data = data, iterations = 50, seed = 42) + expect_type(result, "character") + expect_true(length(result) > 0) + expect_true("Feature_Strong" %in% result) +}) + +test_that("boruta_fn returns character vector", { + set.seed(123) + n <- 80 + data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "day", length.out = n), + Target = rnorm(n, 50, 10), + x1 = rnorm(n), + x2 = rnorm(n) + ) + data$Target <- data$Target + 3 * data$x1 + + result <- boruta_fn(data = data, iterations = 20, seed = 123) + expect_type(result, "character") +}) + +# -- target_corr_fn edge cases -- + +test_that("target_corr_fn with exact threshold boundary", { + set.seed(42) + n <- 200 + x <- rnorm(n) + data <- tibble::tibble( + Combo = rep("A", n), + Date = seq.Date(as.Date("2020-01-01"), by = "day", length.out = n), + Target = x, + Feature_Perfect = x + ) + + result <- target_corr_fn(data, threshold = 0.99) + expect_s3_class(result, "tbl_df") + expect_true("Feature_Perfect" %in% result$term) +}) diff --git a/tests/testthat/test-final_models_helpers.R b/tests/testthat/test-final_models_helpers.R new file mode 100644 index 00000000..b29ff3ad --- /dev/null +++ b/tests/testthat/test-final_models_helpers.R @@ -0,0 +1,255 @@ +# tests/testthat/test-final_models_helpers.R + +test_that("create_prediction_intervals adds interval columns", { + fcst_tbl <- tibble::tibble( + Combo = rep("A", 6), + Model_ID = rep("M1", 6), + Train_Test_ID = c(1, 1, 1, 2, 2, 2), + Target = c(100, 200, 300, 100, 200, 300), + Forecast = c(110, 190, 310, 105, 195, 305), + Horizon = c(1, 2, 3, 1, 2, 3) + ) + + train_test_split <- tibble::tibble( + Run_Type = c("Back_Test", "Future_Forecast"), + Train_Test_ID = c(2, 1) + ) + + result <- create_prediction_intervals(fcst_tbl, train_test_split) + + expect_true("lo_80" %in% colnames(result)) + expect_true("lo_95" %in% colnames(result)) + expect_true("hi_80" %in% colnames(result)) + expect_true("hi_95" %in% colnames(result)) + + # prediction intervals should only be non-NA for Train_Test_ID == 1 + future_rows <- result %>% dplyr::filter(Train_Test_ID == 1) + expect_true(all(!is.na(future_rows$lo_80))) + expect_true(all(!is.na(future_rows$hi_95))) + + back_test_rows <- result %>% dplyr::filter(Train_Test_ID == 2) + expect_true(all(is.na(back_test_rows$lo_80))) +}) + +test_that("create_prediction_intervals calculates intervals correctly", { + fcst_tbl <- tibble::tibble( + Combo = rep("A", 4), + Model_ID = rep("M1", 4), + Train_Test_ID = c(1, 1, 2, 2), + Target = c(100, 200, 100, 200), + Forecast = c(110, 210, 100, 200), + Horizon = c(1, 2, 1, 2) + ) + + train_test_split <- tibble::tibble( + Run_Type = c("Back_Test", "Future_Forecast"), + Train_Test_ID = c(2, 1) + ) + + result <- create_prediction_intervals(fcst_tbl, train_test_split) + + # residuals for back_test: 100-100=0, 200-200=0 -> std_dev = 0 + future_rows <- result %>% dplyr::filter(Train_Test_ID == 1) + expect_equal(future_rows$lo_80, future_rows$Forecast) + expect_equal(future_rows$hi_80, future_rows$Forecast) +}) + +test_that("convert_weekly_to_daily returns unchanged for non-weekly data", { + fcst_tbl <- tibble::tibble( + Combo_ID = "C1", + Model_ID = "M1", + Model_Name = "arima", + Model_Type = "local", + Recipe_ID = "R1", + Train_Test_ID = 1, + Hyperparameter_ID = "H1", + Best_Model = "Yes", + Combo = "A", + Horizon = 1, + Date = as.Date("2020-01-01"), + Target = 100, + Forecast = 110, + lo_95 = 90, + lo_80 = 95, + hi_80 = 125, + hi_95 = 130 + ) + + result <- convert_weekly_to_daily(fcst_tbl, "month", FALSE) + + expect_equal(nrow(result), 1) + expect_true("Date" %in% colnames(result)) + expect_false("Date_Day" %in% colnames(result)) +}) + +test_that("convert_weekly_to_daily returns unchanged when weekly_to_daily is FALSE", { + fcst_tbl <- tibble::tibble( + Combo_ID = "C1", + Model_ID = "M1", + Model_Name = "arima", + Model_Type = "local", + Recipe_ID = "R1", + Train_Test_ID = 1, + Hyperparameter_ID = "H1", + Best_Model = "Yes", + Combo = "A", + Horizon = 1, + Date = as.Date("2020-01-06"), + Target = 700, + Forecast = 700, + lo_95 = 600, + lo_80 = 650, + hi_80 = 750, + hi_95 = 800 + ) + + result <- convert_weekly_to_daily(fcst_tbl, "week", FALSE) + + expect_equal(nrow(result), 1) +}) + +test_that("convert_weekly_to_daily expands weekly to daily", { + fcst_tbl <- tibble::tibble( + Combo_ID = "C1", + Model_ID = "M1", + Model_Name = "arima", + Model_Type = "local", + Recipe_ID = "R1", + Train_Test_ID = 1, + Hyperparameter_ID = "H1", + Best_Model = "Yes", + Combo = "A", + Horizon = 1, + Date = as.Date("2020-01-06"), + Target = 700, + Forecast = 700, + lo_95 = 630, + lo_80 = 700, + hi_80 = 700, + hi_95 = 770 + ) + + result <- convert_weekly_to_daily(fcst_tbl, "week", TRUE) + + expect_equal(nrow(result), 7) + expect_true("Date_Day" %in% colnames(result)) + # daily values should be weekly / 7 + expect_equal(unique(result$Target), 100) + expect_equal(unique(result$Forecast), 100) +}) + +test_that("remove_best_model removes Best_Model column", { + df <- tibble::tibble( + Model_ID = "M1", + Forecast = 100, + Best_Model = "Yes" + ) + + result <- remove_best_model(df) + + expect_false("Best_Model" %in% colnames(result)) + expect_true("Model_ID" %in% colnames(result)) +}) + +test_that("remove_best_model leaves df unchanged without Best_Model", { + df <- tibble::tibble( + Model_ID = "M1", + Forecast = 100 + ) + + result <- remove_best_model(df) + + expect_equal(ncol(result), 2) +}) + +test_that("adjust_combo_column converts logical Combo to character", { + df <- tibble::tibble( + Combo = c(TRUE, FALSE, TRUE), + Forecast = c(100, 200, 300) + ) + + result <- adjust_combo_column(df) + + expect_type(result$Combo, "character") + expect_equal(result$Combo, c("T", "F", "T")) +}) + +test_that("adjust_combo_column leaves character Combo unchanged", { + df <- tibble::tibble( + Combo = c("A", "B", "C"), + Forecast = c(100, 200, 300) + ) + + result <- adjust_combo_column(df) + + expect_equal(result$Combo, c("A", "B", "C")) +}) + +test_that("adjust_combo_column handles missing Combo column", { + df <- tibble::tibble( + Model_ID = "M1", + Forecast = 100 + ) + + result <- adjust_combo_column(df) + + expect_equal(ncol(result), 2) +}) + +# -- create_prediction_intervals with multiple models -- + +test_that("create_prediction_intervals handles multiple Model_IDs", { + fcst_tbl <- tibble::tibble( + Combo = rep("A", 8), + Model_ID = c(rep("M1", 4), rep("M2", 4)), + Train_Test_ID = rep(c(1, 1, 2, 2), 2), + Target = c(100, 200, 100, 200, 100, 200, 100, 200), + Forecast = c(110, 210, 95, 195, 120, 220, 105, 205), + Horizon = rep(c(1, 2, 1, 2), 2) + ) + + train_test_split <- tibble::tibble( + Run_Type = c("Back_Test", "Future_Forecast"), + Train_Test_ID = c(2, 1) + ) + + result <- create_prediction_intervals(fcst_tbl, train_test_split) + + # Should have interval columns for both models + m1_future <- result %>% dplyr::filter(Model_ID == "M1", Train_Test_ID == 1) + m2_future <- result %>% dplyr::filter(Model_ID == "M2", Train_Test_ID == 1) + + expect_true(all(!is.na(m1_future$lo_80))) + expect_true(all(!is.na(m2_future$lo_80))) +}) + +# -- convert_weekly_to_daily with multiple weeks -- + +test_that("convert_weekly_to_daily expands multiple weeks", { + fcst_tbl <- tibble::tibble( + Combo_ID = rep("C1", 2), + Model_ID = rep("M1", 2), + Model_Name = rep("model", 2), + Model_Type = rep("type", 2), + Recipe_ID = rep("R1", 2), + Train_Test_ID = rep(1, 2), + Hyperparameter_ID = rep("H1", 2), + Best_Model = rep("M1", 2), + Combo = rep("C1", 2), + Horizon = c(1, 2), + Date = as.Date(c("2020-01-06", "2020-01-13")), + Target = c(70, 140), + Forecast = c(70, 140), + lo_80 = c(60, 130), + lo_95 = c(50, 120), + hi_80 = c(80, 150), + hi_95 = c(90, 160) + ) + + result <- convert_weekly_to_daily(fcst_tbl, "week", weekly_to_daily = TRUE) + + # Each week should expand to 7 daily rows + expect_equal(nrow(result), 14) + # Daily values should be 1/7 of weekly + expect_equal(result$Forecast[1], 10) +}) diff --git a/tests/testthat/test-hierarchical.R b/tests/testthat/test-hierarchical.R index 61bc0c1b..85f01fd5 100644 --- a/tests/testthat/test-hierarchical.R +++ b/tests/testthat/test-hierarchical.R @@ -1,5 +1,9 @@ +# tests/testthat/test-hierarchical.R +# Tests for hierarchy.R functions + +# -- prep_hierarchical_data tests -- + test_that("prep_hierarchical_data returns correct grouped hierarchies", { - # Mock data setup data <- tibble::tibble( Segment = as.character(c( "Commercial", "Commercial", "Commercial", "Commercial", "Commercial", "Commercial", @@ -38,7 +42,6 @@ test_that("prep_hierarchical_data returns correct grouped hierarchies", { remove = F ) - # run prep hts function result_data <- prep_hierarchical_data( input_data = data, run_info = set_run_info(), @@ -49,7 +52,6 @@ test_that("prep_hierarchical_data returns correct grouped hierarchies", { ) %>% dplyr::filter(Date == "2020-01-01") - # Expected output setup expected_data <- tibble::tibble( Combo = as.character(c( "Total", "Segment_Commercial", "Segment_Consumer", "Country_United_States", "Country_UK", @@ -70,12 +72,10 @@ test_that("prep_hierarchical_data returns correct grouped hierarchies", { Value_Segment_Product = c(1000, 300, 700, 1000, 1000, 400, 600, 100, 200, 100, 200, 300, 400, 300, 400) ) - # Assertions expect_equal(result_data, expected_data) }) test_that("prep_hierarchical_data returns correct standard hierarchies", { - # Mock data setup data <- tibble::tibble( Area = as.character(c("EMEA", "EMEA", "EMEA", "EMEA", "EMEA", "EMEA", "EMEA", "EMEA", "United States", "United States", "United States", "United States")), Country = as.character(c("Croatia", "Croatia", "Croatia", "Croatia", "Greece", "Greece", "Greece", "Greece", "United States", "United States", "United States", "United States")), @@ -91,7 +91,6 @@ test_that("prep_hierarchical_data returns correct standard hierarchies", { remove = F ) - # run prep hts function for standard hierarchy result_data <- prep_hierarchical_data( input_data = data, run_info = set_run_info(), @@ -102,7 +101,6 @@ test_that("prep_hierarchical_data returns correct standard hierarchies", { ) %>% dplyr::filter(Date == "2020-01-01") - # Expected output setup for a standard hierarchical forecast expected_data <- tibble::tibble( Combo = as.character(c("Total", "A", "B", "EMEA_Croatia", "EMEA_Greece", "United_States_United_States")), Date = as.Date(c("2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01")), @@ -112,6 +110,276 @@ test_that("prep_hierarchical_data returns correct standard hierarchies", { Value_Area = c(90, 90, 90, 20, 20, 70) ) - # Assertions expect_equal(result_data, expected_data) }) + +# -- summarize_standard_hierarchy tests -- + +test_that("summarize_standard_hierarchy creates correct summary", { + skip_if_not_installed("hts") + + original_combos <- c("A", "B", "C", "D") + nodes <- list(2, c(2, 2)) + + dummy_data <- matrix(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), nrow = 3, ncol = 4) + colnames(dummy_data) <- original_combos + ts_data <- stats::ts(dummy_data, frequency = 12) + hts_obj <- hts::hts(ts_data, nodes = nodes) %>% suppressMessages() + S <- hts::smatrix(hts_obj) + hts_combos <- paste0("Level_", seq_len(nrow(S))) + + result <- summarize_standard_hierarchy(original_combos, hts_combos, nodes) + + expect_s3_class(result, "data.frame") + expect_true("Hierarchy_Level" %in% names(result)) + expect_true("Level_Type" %in% names(result)) + expect_true("Original_Combos" %in% names(result)) + expect_true("Num_Bottom_Series" %in% names(result)) + expect_true("Total" %in% result$Level_Type) + expect_true("Bottom" %in% result$Level_Type) + total_row <- result[result$Level_Type == "Total", ] + expect_equal(total_row$Num_Bottom_Series, 4) + bottom_rows <- result[result$Level_Type == "Bottom", ] + expect_true(all(bottom_rows$Num_Bottom_Series == 1)) +}) + +test_that("summarize_standard_hierarchy orders by level type", { + skip_if_not_installed("hts") + + original_combos <- c("A", "B", "C", "D") + nodes <- list(2, c(2, 2)) + + dummy_data <- matrix(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), nrow = 3, ncol = 4) + colnames(dummy_data) <- original_combos + ts_data <- stats::ts(dummy_data, frequency = 12) + hts_obj <- hts::hts(ts_data, nodes = nodes) %>% suppressMessages() + S <- hts::smatrix(hts_obj) + hts_combos <- paste0("Node_", seq_len(nrow(S))) + + result <- summarize_standard_hierarchy(original_combos, hts_combos, nodes) + + expect_equal(result$Level_Type[1], "Total") + expect_true(all(which(result$Level_Type == "Bottom") > which(result$Level_Type == "Total"))) +}) + +# -- summarize_grouped_hierarchy tests -- + +test_that("summarize_grouped_hierarchy creates correct summary", { + skip_if_not_installed("hts") + + original_combos <- c("A_X", "A_Y", "B_X", "B_Y") + + nodes <- matrix( + c("A", "A", "B", "B", + "X", "Y", "X", "Y"), + nrow = 2, byrow = TRUE + ) + rownames(nodes) <- c("Group1", "Group2") + + hts_combos <- c( + "Total", + "A", "B", + "X", "Y", + "A_X", "A_Y", "B_X", "B_Y" + ) + + result <- summarize_grouped_hierarchy(original_combos, hts_combos, nodes) + + expect_s3_class(result, "data.frame") + expect_true("Hierarchy_Level" %in% names(result)) + expect_true("Level_Type" %in% names(result)) + expect_true("Total" %in% result$Level_Type) + expect_true("Group1" %in% result$Level_Type) + expect_true("Group2" %in% result$Level_Type) + expect_true("Bottom" %in% result$Level_Type) + + total_row <- result[result$Level_Type == "Total", ] + expect_equal(total_row$Num_Bottom_Series, 4) + + bottom_rows <- result[result$Level_Type == "Bottom", ] + expect_equal(nrow(bottom_rows), 4) + expect_true(all(bottom_rows$Num_Bottom_Series == 1)) +}) + +test_that("summarize_grouped_hierarchy matches combos to groups", { + skip_if_not_installed("hts") + + original_combos <- c("X1", "X2", "Y1") + + nodes <- matrix( + c("X", "X", "Y"), + nrow = 1, byrow = TRUE + ) + rownames(nodes) <- c("Category") + + hts_combos <- c("Total", "X", "Y", "X1", "X2", "Y1") + + result <- summarize_grouped_hierarchy(original_combos, hts_combos, nodes) + + x_row <- result[result$Hierarchy_Level == "X", ] + expect_equal(x_row$Num_Bottom_Series, 2) + + y_row <- result[result$Hierarchy_Level == "Y", ] + expect_equal(y_row$Num_Bottom_Series, 1) +}) + +# -- adjust_df tests -- + +test_that("adjust_df returns data frame as-is for df return_type", { + df <- tibble::tibble( + Combo = rep("A", 5), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 5), + Target = c(1, 2, 3, 4, 5) + ) + + result <- adjust_df(df, return_type = "df") + expect_s3_class(result, "tbl_df") + expect_equal(nrow(result), 5) + expect_equal(result$Target, df$Target) +}) + +# -- get_hts tests -- + +test_that("get_hts creates hts object for standard hierarchy", { + skip_if_not_installed("hts") + + mat <- matrix(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), nrow = 5, ncol = 2) + colnames(mat) <- c("A", "B") + ts_data <- stats::ts(mat, start = c(2020, 1), frequency = 12) + + result <- get_hts(ts_data, nodes = list(2), forecast_approach = "standard_hierarchy") + expect_true(inherits(result, "hts")) +}) + +test_that("get_hts creates gts object for grouped hierarchy", { + skip_if_not_installed("hts") + + mat <- matrix(1:15, nrow = 5, ncol = 3) + colnames(mat) <- c("A_X", "A_Y", "B_X") + ts_data <- stats::ts(mat, start = c(2020, 1), frequency = 12) + + groups <- matrix( + c("A", "A", "B", + "X", "Y", "X"), + nrow = 2, byrow = TRUE + ) + rownames(groups) <- c("Group1", "Group2") + + result <- get_hts(ts_data, nodes = groups, forecast_approach = "grouped_hierarchy") + expect_true(inherits(result, "gts")) +}) + +# -- get_hts_nodes tests -- + +test_that("get_hts_nodes returns nodes for standard hierarchy", { + skip_if_not_installed("hts") + + mat <- matrix(1:10, nrow = 5, ncol = 2) + colnames(mat) <- c("A", "B") + ts_data <- stats::ts(mat, start = c(2020, 1), frequency = 12) + hts_obj <- hts::hts(ts_data, nodes = list(2)) %>% suppressMessages() + + result <- get_hts_nodes(hts_obj, forecast_approach = "standard_hierarchy") + expect_true(is.list(result)) +}) + +# -- get_grouped_nodes tests -- + +test_that("get_grouped_nodes creates grouping matrix", { + input_data <- data.frame( + Segment = c("A", "A", "B", "B"), + Country = c("US", "UK", "US", "UK") + ) + + result <- get_grouped_nodes(input_data, c("Segment", "Country")) + expect_true(is.matrix(result)) + expect_equal(nrow(result), 2) + expect_equal(ncol(result), 4) + expect_equal(rownames(result), c("Segment", "Country")) +}) + +test_that("get_grouped_nodes with single variable", { + input_data <- data.frame( + Region = c("East", "West", "East", "West") + ) + + result <- get_grouped_nodes(input_data, "Region") + expect_true(is.matrix(result)) + expect_equal(nrow(result), 1) +}) + +# -- get_standard_nodes tests -- + +test_that("get_standard_nodes creates node list", { + input_data <- data.frame( + Area = c("A", "A", "B"), + Country = c("X", "Y", "Z") + ) + + result <- get_standard_nodes(input_data, c("Area", "Country")) + expect_true(is.list(result)) + expect_true(length(result) >= 1) +}) + +test_that("get_standard_nodes with 3 levels", { + input_data <- data.frame( + Region = c("NA", "NA", "NA", "EU", "EU", "EU"), + Country = c("US", "US", "CA", "UK", "UK", "FR"), + City = c("NY", "LA", "TO", "LO", "MA", "PA") + ) + + result <- get_standard_nodes(input_data, c("Region", "Country", "City")) + expect_true(is.list(result)) +}) + +# -- adjust_df tests -- + +test_that("adjust_df with df return_type collects data", { + input_data <- tibble::tibble(x = 1:5, y = letters[1:5]) + result <- adjust_df(input_data, return_type = "df") + expect_s3_class(result, "tbl_df") + expect_equal(nrow(result), 5) +}) + +# -- pick_right_hierarchy tests -- + +test_that("pick_right_hierarchy routes to grouped_nodes", { + input_data <- data.frame( + Segment = c("A", "A", "B", "B"), + Country = c("US", "UK", "US", "UK") + ) + + result <- pick_right_hierarchy(input_data, c("Segment", "Country"), "grouped_hierarchy") + expect_true(is.matrix(result)) +}) + +test_that("pick_right_hierarchy routes to standard_nodes", { + input_data <- data.frame( + Area = c("A", "A", "B"), + Country = c("X", "Y", "Z") + ) + + result <- pick_right_hierarchy(input_data, c("Area", "Country"), "standard_hierarchy") + expect_true(is.list(result)) +}) + +# -- get_hts_nodes for grouped hierarchy -- + +test_that("get_hts_nodes returns groups for grouped hierarchy", { + skip_if_not_installed("hts") + + mat <- matrix(1:15, nrow = 5, ncol = 3) + colnames(mat) <- c("AX", "AY", "BX") + ts_data <- stats::ts(mat, start = c(2020, 1), frequency = 12) + + groups <- matrix( + c("A", "A", "B", + "X", "Y", "X"), + nrow = 2, byrow = TRUE + ) + rownames(groups) <- c("Group1", "Group2") + + gts_obj <- hts::gts(ts_data, groups = groups) %>% suppressMessages() + result <- get_hts_nodes(gts_obj, forecast_approach = "grouped_hierarchy") + expect_true(!is.null(result)) +}) diff --git a/tests/testthat/test-input_checks.R b/tests/testthat/test-input_checks.R new file mode 100644 index 00000000..5ca50772 --- /dev/null +++ b/tests/testthat/test-input_checks.R @@ -0,0 +1,349 @@ +# tests/testthat/test-input_checks.R + +test_that("check_input_type validates character type", { + expect_silent(check_input_type("x", "hello", "character")) + expect_error( + check_input_type("x", 123, "character"), + "invalid type for input name 'x'" + ) +}) + +test_that("check_input_type validates numeric type", { + expect_silent(check_input_type("x", 42, "numeric")) + expect_error( + check_input_type("x", "abc", "numeric"), + "invalid type for input name 'x'" + ) +}) + +test_that("check_input_type validates logical type", { + expect_silent(check_input_type("x", TRUE, "logical")) + expect_error( + check_input_type("x", "yes", "logical"), + "invalid type for input name 'x'" + ) +}) + +test_that("check_input_type validates expected values", { + expect_silent(check_input_type("x", "csv", "character", c("csv", "parquet"))) + expect_error( + check_input_type("x", "json", "character", c("csv", "parquet")), + "invalid value for input name 'x'" + ) +}) + +test_that("check_input_type accepts multiple types", { + expect_silent(check_input_type("x", "hello", c("character", "numeric"))) + expect_silent(check_input_type("x", 42, c("character", "numeric"))) +}) + +test_that("check_input_type passes with NULL expected_value", { + expect_silent(check_input_type("x", "anything", "character", NULL)) +}) + +test_that("check_input_data catches missing combo variables", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Target = 100, + id = "A" + ) + + expect_error( + check_input_data( + data, + combo_variables = c("missing_col"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "combo variables do not match" + ) +}) + +test_that("check_input_data catches missing target variable", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Value = 100, + id = "A" + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "target variable does not match" + ) +}) + +test_that("check_input_data catches non-numeric target", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Target = "abc", + id = "A" + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "Target variable in input data needs to be numeric" + ) +}) + +test_that("check_input_data catches missing Date column", { + data <- tibble::tibble( + date_col = as.Date("2020-01-01"), + Target = 100, + id = "A" + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "date column in input data needs to be named as 'Date'" + ) +}) + +test_that("check_input_data catches non-date Date column", { + data <- tibble::tibble( + Date = "2020-01-01", + Target = 100, + id = "A" + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "date column in input data needs to be formatted as a date" + ) +}) + +test_that("check_input_data catches invalid fiscal_year_start", { + data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Target = c(100, 200, 300), + id = c("A", "A", "A") + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 13, + parallel_processing = NULL + ), + "fiscal year start should be a number from 1 to 12" + ) +}) + +test_that("check_input_data catches missing external regressors", { + data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01")), + Target = c(100, 200), + id = c("A", "A") + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = c("missing_xreg"), + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "external regressors do not match" + ) +}) + +test_that("check_input_data catches duplicate rows", { + data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-01-01", "2020-02-01")), + Target = c(100, 100, 200), + id = c("A", "A", "A") + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "duplicate rows have been detected" + ) +}) + +test_that("check_input_data passes with valid data", { + data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Target = c(100, 200, 300), + id = c("A", "A", "A") + ) + + expect_silent( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ) + ) +}) + +test_that("check_parallel_processing passes with NULL", { + run_info <- list(path = "/tmp/test") + + expect_silent( + check_parallel_processing( + run_info, + parallel_processing = NULL + ) + ) +}) + +test_that("check_parallel_processing errors on invalid value", { + run_info <- list(path = "/tmp/test") + + expect_error( + check_parallel_processing( + run_info, + parallel_processing = "invalid" + ), + "parallel processing input must be one of these values" + ) +}) + +test_that("check_parallel_processing errors on local_machine with inner_parallel", { + run_info <- list(path = "/tmp/test") + + expect_error( + check_parallel_processing( + run_info, + parallel_processing = "local_machine", + inner_parallel = TRUE + ), + "cannot run parallel process" + ) +}) + +test_that("check_parallel_processing errors on spark without sc", { + run_info <- list(path = "/dbfs/test") + + expect_error( + check_parallel_processing( + run_info, + parallel_processing = "spark" + ), + "Ensure that you are connected to a spark cluster" + ) +}) + +test_that("check_agent_info validates list type", { + expect_error( + check_agent_info("not_a_list"), + "agent_info should be a list" + ) +}) + +test_that("check_agent_info catches missing elements", { + incomplete_info <- list( + agent_version = "1.0", + run_id = "test" + ) + + expect_error( + check_agent_info(incomplete_info), + "agent_info is missing required elements" + ) +}) + +test_that("check_agent_info passes with all required elements", { + full_info <- list( + agent_version = "1.0", + run_id = "test", + project_info = list(), + driver_llm = "gpt-4", + reason_llm = "gpt-4", + forecast_horizon = 3, + external_regressors = NULL, + hist_end_date = as.Date("2024-01-01"), + back_test_scenarios = 3, + back_test_spacing = 1, + combo_cleanup_date = NULL, + overwrite = FALSE + ) + + expect_silent( + check_agent_info(full_info) + ) +}) + +test_that("check_input_data catches unevenly spaced month dates", { + data <- tibble::tibble( + Date = as.Date(c("2020-01-15", "2020-02-20", "2020-03-01")), + Target = c(100, 200, 300), + id = c("A", "A", "A") + ) + + expect_error( + check_input_data( + data, + combo_variables = c("id"), + target_variable = "Target", + external_regressors = NULL, + date_type = "month", + fiscal_year_start = 1, + parallel_processing = NULL + ), + "historical date values are not evenly spaced" + ) +}) + +# -- check_input_type edge cases -- + +test_that("check_input_type accepts NULL with NULL type", { + expect_silent(check_input_type("x", NULL, c("character", "NULL"))) +}) + +test_that("check_input_type accepts vector of expected values", { + expect_silent(check_input_type("x", c("R1", "R2"), "character", c("R1", "R2", "R3"))) +}) diff --git a/tests/testthat/test-models.R b/tests/testthat/test-models.R new file mode 100644 index 00000000..700ebcd8 --- /dev/null +++ b/tests/testthat/test-models.R @@ -0,0 +1,918 @@ +# tests/testthat/test-models.R +# Tests for models.R, model workflows, and train_models.R helpers + +# -- Model listing functions -- + +test_that("list_models returns expected model list", { + models <- list_models() + + expect_type(models, "character") + expect_true(length(models) == 23) + expect_true("arima" %in% models) + expect_true("ets" %in% models) + expect_true("xgboost" %in% models) + expect_true("cubist" %in% models) + expect_true("glmnet" %in% models) + expect_true("prophet" %in% models) + expect_true("snaive" %in% models) + expect_true("timegpt" %in% models) + expect_true("meanf" %in% models) + expect_true("theta" %in% models) +}) + +test_that("list_hyperparmater_models returns hyperparameter models", { + models <- list_hyperparmater_models() + + expect_type(models, "character") + expect_true(length(models) == 13) + expect_true("xgboost" %in% models) + expect_true("cubist" %in% models) + expect_true("glmnet" %in% models) + expect_true("timegpt" %in% models) + expect_false("arima" %in% models) + expect_false("ets" %in% models) + expect_false("snaive" %in% models) +}) + +test_that("list_ensemble_models returns ensemble models", { + models <- list_ensemble_models() + + expect_type(models, "character") + expect_true(length(models) == 5) + expect_true("cubist" %in% models) + expect_true("glmnet" %in% models) + expect_true("xgboost" %in% models) + expect_true("svm-poly" %in% models) + expect_true("svm-rbf" %in% models) +}) + +test_that("list_r2_models returns R2 recipe models", { + models <- list_r2_models() + + expect_type(models, "character") + expect_true(length(models) == 5) + expect_true("cubist" %in% models) + expect_true("xgboost" %in% models) +}) + +test_that("list_global_models returns global models", { + models <- list_global_models() + + expect_type(models, "character") + expect_true(length(models) == 2) + expect_true("xgboost" %in% models) + expect_true("timegpt" %in% models) +}) + +test_that("list_multivariate_models returns multivariate models", { + models <- list_multivariate_models() + + expect_type(models, "character") + expect_true(length(models) == 12) + expect_true("cubist" %in% models) + expect_true("arimax" %in% models) + expect_true("timegpt" %in% models) +}) + +test_that("list_multistep_models returns multistep models", { + models <- list_multistep_models() + + expect_type(models, "character") + expect_true(length(models) == 6) + expect_true("cubist" %in% models) + expect_true("glmnet" %in% models) + expect_true("mars" %in% models) + expect_true("xgboost" %in% models) + expect_true("svm-poly" %in% models) + expect_true("svm-rbf" %in% models) +}) + +# -- Recipe functions -- + +test_that("get_recipe_simple creates a recipe with Target ~ Date", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300) + ) + + recipe <- get_recipe_simple(train_data) + expect_s3_class(recipe, "recipe") + expect_true("Target" %in% recipe$var_info$variable) + expect_true("Date" %in% recipe$var_info$variable) +}) + +test_that("get_recipe_combo creates a recipe with Target ~ Date + Combo", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300) + ) + + recipe <- get_recipe_combo(train_data) + expect_s3_class(recipe, "recipe") + expect_true("Combo" %in% recipe$var_info$variable) +}) + +test_that("get_recipe_configurable creates a recipe with default settings", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Feature1 = c(1, 2, 3) + ) + + recipe <- get_recipe_configurable(train_data) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles rm_date options", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Date_index.num = c(1, 2, 3) + ) + + recipe_with_adj <- get_recipe_configurable(train_data, rm_date = "with_adj") + expect_s3_class(recipe_with_adj, "recipe") + + recipe_with_adj_index <- get_recipe_configurable(train_data, rm_date = "with_adj_index") + expect_s3_class(recipe_with_adj_index, "recipe") + + recipe_none <- get_recipe_configurable(train_data, rm_date = "none") + expect_s3_class(recipe_none, "recipe") +}) + +test_that("get_recipe_configurable handles center_scale option", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Feature1 = c(1, 2, 3) + ) + + recipe <- get_recipe_configurable(train_data, center_scale = TRUE, pca = FALSE) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles pca = FALSE", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Feature1 = c(1, 2, 3) + ) + + recipe <- get_recipe_configurable(train_data, pca = FALSE) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles step_nzv options", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Feature1 = c(1, 2, 3) + ) + + recipe_zv <- get_recipe_configurable(train_data, step_nzv = "zv") + expect_s3_class(recipe_zv, "recipe") + + recipe_nzv <- get_recipe_configurable(train_data, step_nzv = "nzv") + expect_s3_class(recipe_nzv, "recipe") +}) + +test_that("get_recipe_configurable handles dummy_one_hot FALSE", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Category = c("X", "Y", "Z") + ) + + recipe <- get_recipe_configurable( + train_data, + dummy_one_hot = FALSE, + character_factor = TRUE, + pca = FALSE + ) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles corr and lincomb", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Feature1 = c(1, 2, 3), + Feature2 = c(10, 20, 30) + ) + + recipe <- get_recipe_configurable( + train_data, + corr = TRUE, + lincomb = TRUE, + pca = FALSE + ) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles mutate_adj_half", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Date_half = c(1, 1, 1), + Date_quarter = c(1, 1, 1) + ) + + recipe <- get_recipe_configurable(train_data, mutate_adj_half = TRUE, pca = FALSE) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable handles norm_date_adj_year", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + Date_index.num = c(1, 2, 3), + Date_year = c(2020, 2020, 2020) + ) + + recipe <- get_recipe_configurable(train_data, norm_date_adj_year = TRUE, pca = FALSE) + expect_s3_class(recipe, "recipe") +}) + +test_that("get_recipe_configurable removes _original columns", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300), + temp_original = c(10, 20, 30) + ) + + recipe <- get_recipe_configurable(train_data, pca = FALSE) + expect_s3_class(recipe, "recipe") + step_ids <- sapply(recipe$steps, function(s) s$id) + expect_true("step_remove_original" %in% step_ids) +}) + +# -- Helper to create minimal training data for model workflows -- + +make_train_data <- function(n = 36) { + tibble::tibble( + Combo = rep("test_combo", n), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = n), + Target = sin(seq_len(n) / 6 * pi) * 10 + 50 + rnorm(n, sd = 2) + ) +} + +test_that("get_recipe_configurable with arimax options creates valid recipe", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- get_recipe_configurable( + train_data, + step_nzv = "zv", + dummy_one_hot = TRUE, + corr = TRUE, + pca = FALSE, + lincomb = TRUE + ) + expect_s3_class(result, "recipe") +}) + +# -- Workflow functions -- + +test_that("get_workflow_simple creates valid workflow", { + train_data <- tibble::tibble( + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Combo = c("A", "A", "A"), + Target = c(100, 200, 300) + ) + + recipe <- get_recipe_simple(train_data) + model_spec <- parsnip::linear_reg() %>% parsnip::set_engine("lm") + wflw <- get_workflow_simple(model_spec, recipe) + + expect_s3_class(wflw, "workflow") +}) + +# -- Resampling functions -- + +test_that("get_resample_kfold returns vfold_cv object", { + train_data <- tibble::tibble( + x = rnorm(100), + y = rnorm(100) + ) + + result <- get_resample_kfold(train_data) + expect_s3_class(result, "vfold_cv") +}) + +test_that("get_resample_tscv creates time series CV splits", { + train_data <- make_train_data(60) %>% dplyr::select(-Combo) + result <- get_resample_tscv( + train_data = train_data, + tscv_initial = 24, + horizon = 3, + back_test_spacing = 6 + ) + expect_s3_class(result, "rset") +}) + +test_that("get_space_filling_grid creates parameter grid", { + model_spec <- parsnip::boost_tree( + mode = "regression", + trees = tune::tune(), + tree_depth = tune::tune(), + learn_rate = tune::tune() + ) %>% parsnip::set_engine("xgboost") + + result <- get_space_filling_grid(model_spec) + expect_s3_class(result, "tbl_df") + expect_equal(nrow(result), 10) +}) + +# -- Simple univariate model workflow tests -- + +test_that("arima creates a workflow", { + train_data <- make_train_data() + result <- arima(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("croston creates a workflow", { + train_data <- make_train_data() + result <- croston(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("ets creates a workflow", { + train_data <- make_train_data() + result <- ets(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("theta creates a workflow", { + train_data <- make_train_data() + result <- theta(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("tbats creates a workflow", { + train_data <- make_train_data() + result <- tbats(train_data, seasonal_period = c(12, NA, NA)) + expect_s3_class(result, "workflow") +}) + +test_that("stlm_arima creates a workflow", { + train_data <- make_train_data() + result <- stlm_arima(train_data, seasonal_period = c(12, NA, NA)) + expect_s3_class(result, "workflow") +}) + +test_that("stlm_ets creates a workflow", { + train_data <- make_train_data() + result <- stlm_ets(train_data, seasonal_period = c(12, NA, NA)) + expect_s3_class(result, "workflow") +}) + +test_that("prophet creates a workflow", { + train_data <- make_train_data() + result <- prophet(train_data) + expect_s3_class(result, "workflow") +}) + +test_that("nnetar creates a workflow", { + train_data <- make_train_data() + result <- nnetar(train_data, horizon = 3, frequency = 12) + expect_s3_class(result, "workflow") +}) + +# -- Multivariate model workflow tests (non-multistep) -- + +test_that("arimax creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- arimax(train_data, frequency = 12, pca = FALSE) + expect_s3_class(result, "workflow") +}) + +test_that("arima_boost creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- arima_boost(train_data, frequency = 12, pca = FALSE) + expect_s3_class(result, "workflow") +}) + +# -- ML model workflows (non-multistep) -- + +test_that("glmnet non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- glmnet( + train_data, pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("mars non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- mars( + train_data, pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("cubist non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- cubist( + train_data, pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("svm_poly non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- svm_poly( + train_data, model_type = "single", pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("svm_rbf non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- svm_rbf( + train_data, model_type = "single", pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("xgboost non-multistep creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- xgboost( + train_data, pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +# -- ML model workflows (multistep) -- + +test_that("glmnet multistep creates a workflow", { + train_data <- make_train_data() + result <- glmnet( + train_data, pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("xgboost multistep creates a workflow", { + train_data <- make_train_data() + result <- xgboost( + train_data, pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("cubist multistep creates a workflow", { + train_data <- make_train_data() + result <- cubist( + train_data, pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("mars multistep creates a workflow", { + train_data <- make_train_data() + result <- mars( + train_data, pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("svm_poly multistep creates a workflow", { + train_data <- make_train_data() + result <- svm_poly( + train_data, model_type = "single", pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("svm_rbf multistep creates a workflow", { + train_data <- make_train_data() + result <- svm_rbf( + train_data, model_type = "single", pca = FALSE, multistep = TRUE, + horizon = 3, external_regressors = NULL, frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +# -- Ensemble model types -- + +test_that("svm_poly ensemble creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- svm_poly( + train_data, model_type = "ensemble", pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +test_that("svm_rbf ensemble creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- svm_rbf( + train_data, model_type = "ensemble", pca = FALSE, multistep = FALSE, + horizon = 3, external_regressors = "xreg1", frequency = 12 + ) + expect_s3_class(result, "workflow") +}) + +# -- Fitting and tuning helpers -- + +test_that("get_fit_simple fits a workflow", { + train_data <- make_train_data(60) + wflw <- arima(train_data, frequency = 12) + result <- get_fit_simple(train_data, wflw) + expect_s3_class(result, "workflow") + expect_true(workflows::is_trained_workflow(result)) +}) + +# -- train_models.R helper functions -- + +test_that("negative_fcst_adj replaces NA with zero", { + data <- tibble::tibble(Forecast = c(10, NA, 30)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(result$Forecast, c(10, 0, 30)) +}) + +test_that("negative_fcst_adj replaces NaN with zero", { + data <- tibble::tibble(Forecast = c(10, NaN, 30)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(result$Forecast, c(10, 0, 30)) +}) + +test_that("negative_fcst_adj replaces Inf with zero", { + data <- tibble::tibble(Forecast = c(10, Inf, -Inf)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(result$Forecast, c(10, 0, 0)) +}) + +test_that("negative_fcst_adj keeps negative when negative_forecast=TRUE", { + data <- tibble::tibble(Forecast = c(-10, 20, -30)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(result$Forecast, c(-10, 20, -30)) +}) + +test_that("negative_fcst_adj zeroes negatives when negative_forecast=FALSE", { + data <- tibble::tibble(Forecast = c(-10, 20, -30)) + result <- negative_fcst_adj(data, FALSE) + expect_equal(result$Forecast, c(0, 20, 0)) +}) + +test_that("negative_fcst_adj handles all NA/NaN/Inf", { + data <- tibble::tibble(Forecast = c(NA, NaN, Inf, -Inf)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(result$Forecast, c(0, 0, 0, 0)) +}) + +test_that("negative_fcst_adj handles empty data frame", { + data <- tibble::tibble(Forecast = numeric(0)) + result <- negative_fcst_adj(data, TRUE) + expect_equal(nrow(result), 0) +}) + +test_that("create_splits returns manual_rset object", { + data <- tibble::tibble( + Combo = rep("A", 12), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 12), + Target = 1:12 + ) + + train_test_splits <- tibble::tibble( + Run_Type = c("Validation", "Back_Test"), + Train_Test_ID = c(2, 1), + Train_End = as.Date(c("2020-08-01", "2020-09-01")), + Test_End = as.Date(c("2020-10-01", "2020-12-01")) + ) + + result <- create_splits(data, train_test_splits) + + expect_s3_class(result, "rset") + expect_equal(nrow(result), 2) + expect_true("splits" %in% colnames(result)) + expect_true("id" %in% colnames(result)) +}) + +test_that("create_splits assigns correct IDs", { + data <- tibble::tibble( + Combo = rep("A", 12), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 12), + Target = 1:12 + ) + + train_test_splits <- tibble::tibble( + Run_Type = c("Validation"), + Train_Test_ID = c(3), + Train_End = as.Date(c("2020-08-01")), + Test_End = as.Date(c("2020-12-01")) + ) + + result <- create_splits(data, train_test_splits) + expect_equal(result$id, "3") +}) + +# -- undifference / adjust_column_types helpers -- + +test_that("undifference_forecast undifferences with single diff", { + original_target <- cumsum(c(10, rep(1, 9))) + diff_target <- c(NA, diff(original_target)) + + recipe_data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 7), + Target = diff_target[1:7] + ) + + forecast_data <- tibble::tibble( + Train_Test_ID = rep(1, 3), + Date = seq(as.Date("2020-08-01"), by = "month", length.out = 3), + Target = diff_target[8:10], + Forecast = diff_target[8:10] + 0.1 + ) + + diff_tbl <- tibble::tibble( + Combo = "A", + Diff_Value1 = original_target[1], + Diff_Value2 = NA + ) + + result <- undifference_forecast(forecast_data, recipe_data, diff_tbl) + expect_s3_class(result, "tbl_df") + expect_true("Forecast" %in% names(result)) + expect_true("Target" %in% names(result)) + expect_equal(nrow(result), 3) +}) + +test_that("undifference_forecast returns unchanged when no diffs", { + forecast_data <- tibble::tibble( + Date = as.Date(c("2020-07-01", "2020-08-01")), + Target = c(70, 80), + Forecast = c(72, 78), + Train_Test_ID = c(1, 1) + ) + + recipe_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 6), + Target = c(10, 20, 30, 40, 50, 60) + ) + + diff_tbl <- tibble::tibble( + Diff_Value1 = NA_real_, + Diff_Value2 = NA_real_ + ) + + result <- undifference_forecast(forecast_data, recipe_data, diff_tbl) + expect_equal(result, forecast_data) +}) + +test_that("undifference_recipe undifferences recipe data", { + original_target <- cumsum(c(10, rep(1, 9))) + diff_target <- c(NA, diff(original_target)) + + recipe_data <- tibble::tibble( + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 10), + Target = diff_target + ) + + diff_tbl <- tibble::tibble( + Combo = "A", + Diff_Value1 = original_target[1], + Diff_Value2 = NA + ) + + hist_end_date <- as.Date("2020-07-01") + result <- undifference_recipe(recipe_data, diff_tbl, hist_end_date) + expect_s3_class(result, "tbl_df") + expect_true("Target" %in% names(result)) +}) + +test_that("undifference_recipe returns unchanged when no diffs", { + recipe_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 6), + Target = c(10, 20, 30, 40, 50, 60) + ) + + diff_tbl <- tibble::tibble( + Diff_Value1 = NA_real_, + Diff_Value2 = NA_real_ + ) + + result <- undifference_recipe(recipe_data, diff_tbl, as.Date("2020-06-01")) + expect_equal(result, recipe_data) +}) + +test_that("adjust_column_types converts columns to match recipe", { + data <- tibble::tibble( + x1 = c("1", "2", "3"), + x2 = c(1.0, 2.0, 3.0), + Target = c(10, 20, 30) + ) + + recipe <- recipes::recipe(Target ~ ., data = data.frame( + x1 = 1.0, x2 = "a", Target = 1.0 + )) + recipe <- recipes::prep(recipe, training = data.frame( + x1 = 1.0, x2 = "a", Target = 1.0 + )) + + result <- adjust_column_types(data, recipe) + expect_s3_class(result, "tbl_df") +}) + +test_that("adjust_column_types converts numeric columns", { + data <- tibble::tibble( + x = c("1", "2", "3"), + y = c("a", "b", "c") + ) + + recipe <- recipes::recipe(y ~ x, data = tibble::tibble(x = 1.0, y = "a")) + prepped <- recipes::prep(recipe) + + result <- adjust_column_types(data, prepped) + expect_type(result$x, "double") +}) + +# -- Missing model constructor tests -- + +test_that("meanf creates a workflow", { + train_data <- make_train_data() + result <- meanf(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("snaive creates a workflow", { + train_data <- make_train_data() + result <- snaive(train_data, frequency = 12) + expect_s3_class(result, "workflow") +}) + +test_that("get_recipe_timegpt creates a recipe", { + train_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 12), + Combo = rep("A", 12), + Target = rnorm(12), + xreg1_original = rnorm(12) + ) + + result <- get_recipe_timegpt(train_data) + expect_s3_class(result, "recipe") +}) + +test_that("get_fit_wkflw_nocombo fits without Combo column", { + train_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 24), + Target = rnorm(24, mean = 100, sd = 10) + ) + + recipe_spec <- recipes::recipe(Target ~ Date, data = train_data) + model_spec <- parsnip::linear_reg() %>% parsnip::set_engine("lm") + + result <- get_fit_wkflw_nocombo(train_data, model_spec, recipe_spec) + expect_s3_class(result, "workflow") + expect_true(workflows::is_trained_workflow(result)) +}) + +test_that("nnetar_xregs creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- nnetar_xregs(train_data, frequency = 12, pca = FALSE) + expect_s3_class(result, "workflow") +}) + +test_that("prophet_boost creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- prophet_boost(train_data, pca = FALSE) + expect_s3_class(result, "workflow") +}) + +test_that("prophet_xregs creates a workflow", { + train_data <- make_train_data() + train_data$xreg1 <- rnorm(nrow(train_data)) + result <- prophet_xregs(train_data, pca = FALSE) + expect_s3_class(result, "workflow") +}) + +test_that("undifference_forecast handles multiple Train_Test_IDs", { + original <- c(10, 20, 35, 45, 60, 75) + differenced <- diff(original) + + recipe_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 4), + Target = differenced[1:4] + ) + + forecast_data <- tibble::tibble( + Date = c(as.Date("2020-04-01"), as.Date("2020-05-01")), + Target = differenced[4:5], + Forecast = c(14, 16), + Train_Test_ID = c(1, 2), + Combo = "A" + ) + + diff_tbl <- tibble::tibble( + Combo = "A", + Diff_Value1 = original[1], + Diff_Value2 = NA_real_ + ) + + result <- undifference_forecast(forecast_data, recipe_data, diff_tbl) + expect_equal(length(unique(result$Train_Test_ID)), 2) +}) + +test_that("undifference_recipe handles double differencing", { + original <- c(10, 20, 35, 55, 80, 110, 145, 185) + differenced <- diff(diff(original)) # 6 values + + recipe_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 8), + Target = c(differenced, 7, 8) # 6 hist + 2 future + ) + + diff_tbl <- tibble::tibble( + Combo = "A", + Diff_Value1 = original[1], + Diff_Value2 = original[2] + ) + + result <- undifference_recipe(recipe_data, diff_tbl, as.Date("2020-07-01")) + expect_s3_class(result, "tbl_df") +}) + +test_that("undifference_recipe handles Target_Original column", { + original <- c(10, 20, 35, 45, 60, 75, 90) + differenced <- diff(original) # 6 values + + recipe_data <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 8), + Target = c(differenced, 18, 20), + Target_Original = c(differenced + 1, 19, 21) + ) + + diff_tbl <- tibble::tibble( + Combo = "A", + Diff_Value1 = original[1], + Diff_Value2 = NA_real_ + ) + + result <- undifference_recipe(recipe_data, diff_tbl, as.Date("2020-07-01")) + expect_true("Target_Original" %in% colnames(result)) +}) + +test_that("adjust_column_types coerces factor to character", { + train_data <- tibble::tibble( + Combo = c("A", "B", "C"), + Date = as.Date(c("2020-01-01", "2020-02-01", "2020-03-01")), + Target = c(10, 20, 30) + ) + + recipe_spec <- recipes::recipe(Target ~ ., data = train_data) + + # Intentionally change a column type + test_data <- train_data + test_data$Combo <- factor(test_data$Combo) + + result <- adjust_column_types(test_data, recipe_spec) + expect_type(result$Combo, "character") +}) + +test_that("negative_fcst_adj preserves positive forecasts", { + fcst_tbl <- tibble::tibble( + Combo = "A", + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 3), + Target = c(10, 20, 30), + Train_Test_ID = rep(1, 3), + Forecast = c(12, 22, 28) + ) + + result <- negative_fcst_adj(fcst_tbl, negative_forecast = FALSE) + expect_equal(result$Forecast, c(12, 22, 28)) +}) diff --git a/tests/testthat/test-multistep_models.R b/tests/testthat/test-multistep_models.R new file mode 100644 index 00000000..675805ac --- /dev/null +++ b/tests/testthat/test-multistep_models.R @@ -0,0 +1,367 @@ +# tests/testthat/test-multistep_models.R +# Tests for multistep model spec/print/update/translate functions + +# GLMNET Multistep ---- + +test_that("glmnet_multistep creates model spec", { + spec <- glmnet_multistep( + mode = "regression", + mixture = 0.5, + penalty = 0.01 + ) + expect_s3_class(spec, "glmnet_multistep") + expect_equal(spec$mode, "regression") +}) + +test_that("print.glmnet_multistep outputs text", { + spec <- glmnet_multistep(mode = "regression") + expect_output(print(spec), "GLMNET Multistep Horizon") +}) + +test_that("update.glmnet_multistep updates parameters", { + spec <- glmnet_multistep(mode = "regression", mixture = 0.5) + updated <- update(spec, mixture = 0.8) + expect_s3_class(updated, "glmnet_multistep") +}) + +test_that("update.glmnet_multistep fresh=TRUE replaces args", { + spec <- glmnet_multistep(mode = "regression", mixture = 0.5, penalty = 0.01) + updated <- update(spec, mixture = 0.9, fresh = TRUE) + expect_s3_class(updated, "glmnet_multistep") +}) + +test_that("translate.glmnet_multistep sets engine", { + spec <- glmnet_multistep(mode = "regression") %>% + parsnip::set_engine("glmnet_multistep_horizon") + translated <- translate(spec) + expect_s3_class(translated, "glmnet_multistep") +}) + +test_that("translate.glmnet_multistep uses default engine", { + spec <- glmnet_multistep(mode = "regression") + expect_message( + translated <- translate(spec), + "glmnet_multistep_horizon" + ) + expect_s3_class(translated, "glmnet_multistep") +}) + +# MARS Multistep ---- + +test_that("mars_multistep creates model spec", { + spec <- mars_multistep( + mode = "regression", + num_terms = 10, + prod_degree = 2 + ) + expect_s3_class(spec, "mars_multistep") + expect_equal(spec$mode, "regression") +}) + +test_that("print.mars_multistep outputs text", { + spec <- mars_multistep(mode = "regression") + expect_output(print(spec), "MARS Multistep Horizon") +}) + +test_that("update.mars_multistep updates parameters", { + spec <- mars_multistep(mode = "regression", num_terms = 10) + updated <- update(spec, num_terms = 20) + expect_s3_class(updated, "mars_multistep") +}) + +test_that("update.mars_multistep fresh=TRUE replaces args", { + spec <- mars_multistep(mode = "regression", num_terms = 10, prod_degree = 2) + updated <- update(spec, num_terms = 15, fresh = TRUE) + expect_s3_class(updated, "mars_multistep") +}) + +test_that("translate.mars_multistep sets engine", { + spec <- mars_multistep(mode = "regression") %>% + parsnip::set_engine("mars_multistep_horizon") + translated <- translate(spec) + expect_s3_class(translated, "mars_multistep") +}) + +test_that("translate.mars_multistep uses default engine", { + spec <- mars_multistep(mode = "regression") + expect_message( + translated <- translate(spec), + "mars_multistep_horizon" + ) + expect_s3_class(translated, "mars_multistep") +}) + +# SVM-POLY Multistep ---- + +test_that("svm_poly_multistep creates model spec", { + spec <- svm_poly_multistep( + mode = "regression", + cost = 1, + degree = 2 + ) + expect_s3_class(spec, "svm_poly_multistep") + expect_equal(spec$mode, "regression") +}) + +test_that("print.svm_poly_multistep outputs text", { + spec <- svm_poly_multistep(mode = "regression") + expect_output(print(spec), "SVM-POLY Multistep Horizon") +}) + +test_that("update.svm_poly_multistep updates parameters", { + spec <- svm_poly_multistep(mode = "regression", cost = 1) + updated <- update(spec, cost = 2) + expect_s3_class(updated, "svm_poly_multistep") +}) + +test_that("update.svm_poly_multistep fresh=TRUE replaces args", { + spec <- svm_poly_multistep(mode = "regression", cost = 1, degree = 2) + updated <- update(spec, cost = 3, fresh = TRUE) + expect_s3_class(updated, "svm_poly_multistep") +}) + +test_that("translate.svm_poly_multistep sets engine", { + spec <- svm_poly_multistep(mode = "regression") %>% + parsnip::set_engine("svm_poly_multistep_horizon") + translated <- translate(spec) + expect_s3_class(translated, "svm_poly_multistep") +}) + +test_that("translate.svm_poly_multistep uses default engine", { + spec <- svm_poly_multistep(mode = "regression") + expect_message( + translated <- translate(spec), + "svm_poly_multistep_horizon" + ) + expect_s3_class(translated, "svm_poly_multistep") +}) + +# SVM-RBF Multistep ---- + +test_that("svm_rbf_multistep creates model spec", { + spec <- svm_rbf_multistep( + mode = "regression", + cost = 1, + rbf_sigma = 0.01 + ) + expect_s3_class(spec, "svm_rbf_multistep") + expect_equal(spec$mode, "regression") +}) + +test_that("print.svm_rbf_multistep outputs text", { + spec <- svm_rbf_multistep(mode = "regression") + expect_output(print(spec), "SVM-RBF Multistep Horizon") +}) + +test_that("update.svm_rbf_multistep updates parameters", { + spec <- svm_rbf_multistep(mode = "regression", cost = 1) + updated <- update(spec, cost = 2) + expect_s3_class(updated, "svm_rbf_multistep") +}) + +test_that("update.svm_rbf_multistep fresh=TRUE replaces args", { + spec <- svm_rbf_multistep(mode = "regression", cost = 1, rbf_sigma = 0.01) + updated <- update(spec, cost = 3, fresh = TRUE) + expect_s3_class(updated, "svm_rbf_multistep") +}) + +test_that("translate.svm_rbf_multistep sets engine", { + spec <- svm_rbf_multistep(mode = "regression") %>% + parsnip::set_engine("svm_rbf_multistep_horizon") + translated <- translate(spec) + expect_s3_class(translated, "svm_rbf_multistep") +}) + +test_that("translate.svm_rbf_multistep uses default engine", { + spec <- svm_rbf_multistep(mode = "regression") + expect_message( + translated <- translate(spec), + "svm_rbf_multistep_horizon" + ) + expect_s3_class(translated, "svm_rbf_multistep") +}) + +# XGBOOST Multistep ---- + +test_that("xgboost_multistep creates model spec", { + spec <- xgboost_multistep( + mode = "regression", + tree_depth = 6, + trees = 100 + ) + expect_s3_class(spec, "xgboost_multistep") + expect_equal(spec$mode, "regression") +}) + +test_that("print.xgboost_multistep outputs text", { + spec <- xgboost_multistep(mode = "regression") + expect_output(print(spec), "XGBoost Multistep Horizon") +}) + +test_that("update.xgboost_multistep updates parameters", { + spec <- xgboost_multistep(mode = "regression", trees = 100) + updated <- update(spec, trees = 200) + expect_s3_class(updated, "xgboost_multistep") +}) + +test_that("update.xgboost_multistep fresh=TRUE replaces args", { + spec <- xgboost_multistep(mode = "regression", trees = 100, tree_depth = 6) + updated <- update(spec, trees = 300, fresh = TRUE) + expect_s3_class(updated, "xgboost_multistep") +}) + +test_that("translate.xgboost_multistep sets engine", { + spec <- xgboost_multistep(mode = "regression") %>% + parsnip::set_engine("xgboost_multistep_horizon") + translated <- translate(spec) + expect_s3_class(translated, "xgboost_multistep") +}) + +test_that("translate.xgboost_multistep uses default engine", { + spec <- xgboost_multistep(mode = "regression") + expect_message( + translated <- translate(spec), + "xgboost_multistep_horizon" + ) + expect_s3_class(translated, "xgboost_multistep") +}) + +test_that("multi_future_xreg_check returns NULL when no external regressors", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Combo = "A", + Target = 100 + ) + + result <- multi_future_xreg_check(data, NULL) + expect_null(result) +}) + +test_that("multi_future_xreg_check returns NULL when no xregs match columns", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Combo = "A", + Target = 100 + ) + + result <- multi_future_xreg_check(data, c("nonexistent_col")) + expect_null(result) +}) + +test_that("multi_future_xreg_check returns matching regressors", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Combo = "A", + Target = 100, + xreg1 = 50, + xreg2 = 25 + ) + + result <- multi_future_xreg_check(data, c("xreg1", "xreg2")) + expect_equal(result, c("xreg1", "xreg2")) +}) + +test_that("multi_future_xreg_check returns only matching regressors", { + data <- tibble::tibble( + Date = as.Date("2020-01-01"), + Combo = "A", + Target = 100, + xreg1 = 50 + ) + + result <- multi_future_xreg_check(data, c("xreg1", "missing_xreg")) + expect_equal(result, c("xreg1")) +}) + +test_that("get_multi_lags returns lags up to and including min above horizon", { + lag_periods <- c(1, 2, 3, 6, 12) + forecast_horizon <- 3 + + result <- get_multi_lags(lag_periods, forecast_horizon) + + # should include all lags from 1 to the first lag >= forecast_horizon (3) + expect_equal(result, c(1, 2, 3)) +}) + +test_that("get_multi_lags with horizon matching a lag period", { + lag_periods <- c(1, 3, 6, 12) + forecast_horizon <- 6 + + result <- get_multi_lags(lag_periods, forecast_horizon) + expect_equal(result, c(1, 3, 6)) +}) + +test_that("get_multi_lags with horizon smaller than smallest lag", { + lag_periods <- c(3, 6, 12) + forecast_horizon <- 1 + + result <- get_multi_lags(lag_periods, forecast_horizon) + expect_equal(result, 3) +}) + +test_that("multi_feature_selection selects correct columns without future xregs", { + data <- tibble::tibble( + Combo = "A", + Target = 100, + Date_num = 1, + lag3_val = 50, + lag6_val = 40, + lag12_val = 30 + ) + + result <- multi_feature_selection( + data, + future_xregs = NULL, + lag_periods = c(3, 6, 12), + lag = 6, + target = FALSE + ) + + expect_true("lag6_val" %in% colnames(result)) + expect_true("lag12_val" %in% colnames(result)) + expect_false("Combo" %in% colnames(result)) + expect_false("Target" %in% colnames(result)) +}) + +test_that("multi_feature_selection includes Combo and Target when target = TRUE", { + data <- tibble::tibble( + Combo = "A", + Target = 100, + Date_num = 1, + lag3_val = 50, + lag6_val = 40, + lag12_val = 30 + ) + + result <- multi_feature_selection( + data, + future_xregs = NULL, + lag_periods = c(3, 6, 12), + lag = 3, + target = TRUE + ) + + expect_true("Combo" %in% colnames(result)) + expect_true("Target" %in% colnames(result)) +}) + +test_that("multi_feature_selection includes future xregs when provided", { + data <- tibble::tibble( + Combo = "A", + Target = 100, + Date_num = 1, + lag3_val = 50, + lag6_val = 40, + xreg1 = 10 + ) + + result <- multi_feature_selection( + data, + future_xregs = c("xreg1"), + lag_periods = c(3, 6), + lag = 3, + target = TRUE + ) + + expect_true("xreg1" %in% colnames(result)) +}) diff --git a/tests/testthat/test-parallel_util.R b/tests/testthat/test-parallel_util.R new file mode 100644 index 00000000..407cbd89 --- /dev/null +++ b/tests/testthat/test-parallel_util.R @@ -0,0 +1,191 @@ +# tests/testthat/test-parallel_util.R + +test_that("get_cores returns cores minus 1 when num_cores is NULL", { + result <- get_cores(NULL) + + expect_equal(result, parallel::detectCores() - 1) +}) + +test_that("get_cores respects num_cores limit", { + result <- get_cores(2) + + expect_true(result <= 2) + expect_true(result <= parallel::detectCores() - 1) +}) + +test_that("get_cores caps at available cores minus 1", { + result <- get_cores(1000) + + expect_equal(result, parallel::detectCores() - 1) +}) + +test_that("par_start returns sequential operator when parallel_processing is NULL", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = NULL, + num_cores = NULL, + task_length = 5 + ) + + expect_type(par_info, "list") + expect_true("packages" %in% names(par_info)) + expect_true("foreach_operator" %in% names(par_info)) + expect_true("cl" %in% names(par_info)) + expect_null(par_info$cl) + expect_true(is.function(par_info$foreach_operator)) +}) + +test_that("par_start returns correct packages for sequential processing", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = NULL, + num_cores = NULL, + task_length = 5 + ) + + expect_true("dplyr" %in% par_info$packages) + expect_true("tibble" %in% par_info$packages) + expect_true("recipes" %in% par_info$packages) +}) + +test_that("par_start with local_machine creates cluster", { + skip_if(parallel::detectCores() <= 1, "Requires more than one core") + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = "local_machine", + num_cores = 2, + task_length = 5 + ) + + expect_false(is.null(par_info$cl)) + expect_true(is.function(par_info$foreach_operator)) + + # clean up + par_end(par_info$cl) +}) + +test_that("par_end cleans up cluster", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = "local_machine", + num_cores = 2, + task_length = 3 + ) + + # should not error + expect_no_error(par_end(par_info$cl)) +}) + +test_that("par_end handles NULL cluster gracefully", { + expect_no_error(par_end(NULL)) +}) + +test_that("cancel_parallel handles none parallel processing", { + par_info <- list( + parallel_processing = NULL, + cl = NULL + ) + + expect_no_error(cancel_parallel(par_info)) +}) + +test_that("cancel_parallel handles local_machine cleanup", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = "local_machine", + num_cores = 2, + task_length = 3 + ) + + expect_no_error(cancel_parallel(par_info)) +}) + +test_that("par_start errors on invalid parallel_processing input", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + expect_error( + par_start( + run_info = run_info, + parallel_processing = "invalid", + num_cores = NULL, + task_length = 5 + ), + "error" + ) +}) + +test_that("par_start adds parquet package when data_output is parquet", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "parquet", + object_output = "rds" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = NULL, + num_cores = NULL, + task_length = 5 + ) + + expect_true("arrow" %in% par_info$packages) +}) + +test_that("par_start adds qs2 package when object_output is qs2", { + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "qs2" + ) + + par_info <- par_start( + run_info = run_info, + parallel_processing = NULL, + num_cores = NULL, + task_length = 5 + ) + + expect_true("qs2" %in% par_info$packages) +}) diff --git a/tests/testthat/test-prep_data_helpers.R b/tests/testthat/test-prep_data_helpers.R new file mode 100644 index 00000000..2d9e9790 --- /dev/null +++ b/tests/testthat/test-prep_data_helpers.R @@ -0,0 +1,516 @@ +# tests/testthat/test-prep_data_helpers.R +# Tests for helper functions in prep_data.R + +test_that("get_frequency_number returns correct values", { + expect_equal(get_frequency_number("year"), 1) + expect_equal(get_frequency_number("quarter"), 4) + expect_equal(get_frequency_number("month"), 12) + expect_equal(get_frequency_number("week"), 365.25 / 7, tolerance = 1e-5) + expect_equal(get_frequency_number("day"), 365.25) +}) + +test_that("get_fourier_periods returns defaults per date_type", { + result_year <- get_fourier_periods(NULL, "year") + expect_equal(result_year, c(1, 2, 3, 4, 5)) + + result_quarter <- get_fourier_periods(NULL, "quarter") + expect_equal(result_quarter, c(1, 2, 3, 4)) + + result_month <- get_fourier_periods(NULL, "month") + expect_equal(result_month, c(3, 6, 9, 12)) + + result_week <- get_fourier_periods(NULL, "week") + expect_equal(result_week, c(2, 4, 8, 12, 24, 48, 52)) + + result_day <- get_fourier_periods(NULL, "day") + expect_equal(result_day, c(7, 14, 21, 28, 56, 84, 168, 252, 336, 365)) +}) + +test_that("get_fourier_periods returns custom when provided", { + custom <- c(5, 10, 15) + result <- get_fourier_periods(custom, "month") + expect_equal(result, custom) +}) + +test_that("get_lag_periods returns defaults for each date_type", { + result_month <- get_lag_periods(NULL, "month", 3) + expect_true(3 %in% result_month) + expect_true(all(result_month >= 3)) + + result_year <- get_lag_periods(NULL, "year", 1) + expect_true(1 %in% result_year) +}) + +test_that("get_lag_periods returns custom when provided", { + custom <- c(1, 5, 10) + result <- get_lag_periods(custom, "month", 3) + expect_equal(result, custom) +}) + +test_that("get_lag_periods multistep monthly", { + result <- get_lag_periods(NULL, "month", 3, multistep_horizon = TRUE) + expect_true(is.numeric(result)) + expect_true(length(result) > 0) +}) + +test_that("get_lag_periods multistep weekly", { + result <- get_lag_periods(NULL, "week", 4, multistep_horizon = TRUE) + expect_true(is.numeric(result)) + expect_true(length(result) > 0) +}) + +test_that("get_lag_periods multistep daily", { + result <- get_lag_periods(NULL, "day", 14, multistep_horizon = TRUE) + expect_true(is.numeric(result)) + expect_true(length(result) > 0) +}) + +test_that("get_lag_periods multistep appends forecast_horizon if needed", { + # large forecast_horizon beyond max default + result <- get_lag_periods(NULL, "month", 24, multistep_horizon = TRUE) + expect_true(24 %in% result) +}) + +test_that("get_lag_periods feature_engineering daily", { + result <- get_lag_periods(NULL, "day", 14, multistep_horizon = TRUE, feature_engineering = TRUE) + expect_true(is.numeric(result)) + expect_true(length(result) >= 3) +}) + +test_that("get_rolling_window_periods returns defaults", { + result_month <- get_rolling_window_periods(NULL, "month") + expect_equal(result_month, c(3, 6, 9, 12)) + + result_year <- get_rolling_window_periods(NULL, "year") + expect_equal(result_year, c(2, 3, 4, 5)) +}) + +test_that("get_rolling_window_periods returns custom when provided", { + custom <- c(2, 4) + result <- get_rolling_window_periods(custom, "month") + expect_equal(result, custom) +}) + +test_that("get_recipes_to_run returns defaults per date_type", { + expect_equal(get_recipes_to_run(NULL, "month"), c("R1", "R2")) + expect_equal(get_recipes_to_run(NULL, "quarter"), c("R1", "R2")) + expect_equal(get_recipes_to_run(NULL, "year"), c("R1", "R2")) + expect_equal(get_recipes_to_run(NULL, "week"), c("R1")) + expect_equal(get_recipes_to_run(NULL, "day"), c("R1")) +}) + +test_that("get_recipes_to_run returns custom when provided", { + custom <- c("R1", "R2", "R3") + result <- get_recipes_to_run(custom, "month") + expect_equal(result, custom) +}) + +test_that("get_date_regex returns correct regex", { + result_year <- get_date_regex("year") + expect_true(grepl("quarter", result_year)) + expect_true(grepl("month", result_year)) + expect_true(grepl("week", result_year)) + + result_day <- get_date_regex("day") + expect_true(grepl("hour", result_day)) + expect_true(grepl("minute", result_day)) + expect_false(grepl("month", result_day)) +}) + +test_that("apply_box_cox transforms data correctly", { + set.seed(123) + df <- tibble::tibble( + Combo = rep("A", 20), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 20), + Target = abs(rnorm(20, mean = 100, sd = 20)) + ) + + result <- apply_box_cox(df) + + expect_type(result, "list") + expect_true("data" %in% names(result)) + expect_true("diff_info" %in% names(result)) + expect_s3_class(result$data, "tbl_df") + expect_true("Box_Cox_Lambda" %in% colnames(result$diff_info)) +}) + +test_that("apply_box_cox preserves non-numeric columns", { + df <- tibble::tibble( + Combo = rep("A", 10), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Target = 1:10 * 10.0 + ) + + result <- apply_box_cox(df) + + expect_true("Combo" %in% colnames(result$data)) + expect_true("Date" %in% colnames(result$data)) +}) + +test_that("combo_cleanup_fn removes combos with zero target", { + df <- tibble::tibble( + Combo = c("A", "A", "A", "B", "B", "B"), + Date = rep(seq.Date(as.Date("2020-01-01"), by = "month", length.out = 3), 2), + Target = c(10, 20, 30, 0, 0, 0) + ) + + result <- combo_cleanup_fn( + df, + combo_cleanup_date = as.Date("2020-01-01"), + hist_end_date = as.Date("2020-03-01") + ) + + expect_equal(unique(result$Combo), "A") +}) + +test_that("combo_cleanup_fn returns all when no cleanup date", { + df <- tibble::tibble( + Combo = c("A", "B"), + Date = as.Date(c("2020-01-01", "2020-01-01")), + Target = c(10, 0) + ) + + result <- combo_cleanup_fn(df, NULL, as.Date("2020-01-01")) + + expect_equal(nrow(result), 2) +}) + +# -- get_date_regex additional branch tests -- + +test_that("get_date_regex returns correct regex for quarter", { + result <- get_date_regex("quarter") + expect_true(grepl("month", result)) + expect_true(grepl("week", result)) + expect_true(grepl("day", result)) + # quarter should NOT be in the regex for quarter date_type + expect_false(grepl("\\(quarter\\)", result)) +}) + +test_that("get_date_regex returns correct regex for month", { + + result <- get_date_regex("month") + expect_true(grepl("week", result)) + expect_true(grepl("day", result)) + # month should NOT be in the regex for month date_type + expect_false(grepl("\\(month\\)", result)) +}) + +test_that("get_date_regex returns correct regex for week", { + result <- get_date_regex("week") + expect_true(grepl("day", result)) + expect_true(grepl("hour", result)) + # week should NOT be in the regex for week date_type + expect_false(grepl("\\(week\\)", result)) +}) + +# -- get_rolling_window_periods additional tests -- + +test_that("get_rolling_window_periods returns defaults for week", { + result <- get_rolling_window_periods(NULL, "week") + expect_equal(result, c(2, 4, 8, 12, 24, 48, 52)) +}) + +test_that("get_rolling_window_periods returns defaults for day", { + result <- get_rolling_window_periods(NULL, "day") + expect_equal(result, c(7, 14, 21, 28, 56, 84, 168, 252, 336, 365)) +}) + +test_that("get_rolling_window_periods returns defaults for quarter", { + result <- get_rolling_window_periods(NULL, "quarter") + expect_equal(result, c(2, 3, 4)) +}) + +# -- apply_box_cox additional tests -- + +test_that("apply_box_cox with Target_Original uses shared lambda", { + set.seed(42) + df <- tibble::tibble( + Combo = rep("A", 20), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 20), + Target = abs(rnorm(20, mean = 100, sd = 20)), + Target_Original = abs(rnorm(20, mean = 100, sd = 20)) + ) + + result <- apply_box_cox(df) + expect_true("Target_Original" %in% colnames(result$data)) + expect_false(is.na(result$diff_info$Box_Cox_Lambda)) +}) + +test_that("apply_box_cox skips constant target", { + df <- tibble::tibble( + Combo = rep("A", 10), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Target = rep(5, 10) # only 1 unique value (<=2) + ) + + result <- apply_box_cox(df) + expect_true(is.na(result$diff_info$Box_Cox_Lambda)) + # Target unchanged since skipped + expect_equal(result$data$Target, rep(5, 10)) +}) + +test_that("apply_box_cox skips columns with 2 or fewer unique values", { + df <- tibble::tibble( + Combo = rep("A", 10), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Target = abs(rnorm(10, mean = 50, sd = 10)), + Binary_Xreg = rep(c(0, 1), 5) # only 2 unique values + ) + + result <- apply_box_cox(df) + # Binary_Xreg should be untouched + expect_equal(result$data$Binary_Xreg, rep(c(0, 1), 5)) +}) + +# -- make_stationary additional tests -- + +test_that("make_stationary returns unchanged data when already stationary", { + set.seed(123) + df <- tibble::tibble( + Combo = rep("A", 30), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 30), + Target = rnorm(30, mean = 0, sd = 1) # white noise, already stationary + ) + + result <- make_stationary(df) + expect_type(result, "list") + expect_true("data" %in% names(result)) + expect_true("diff_info" %in% names(result)) +}) + +test_that("make_stationary differences non-stationary data", { + set.seed(42) + # Create random walk (non-stationary) + df <- tibble::tibble( + Combo = rep("A", 50), + Date = seq.Date(as.Date("2016-01-01"), by = "month", length.out = 50), + Target = cumsum(rnorm(50)) + ) + + result <- make_stationary(df) + expect_type(result, "list") + expect_true("Diff_Value1" %in% colnames(result$diff_info)) + # If differencing occurred, Diff_Value1 should be the first Target value + if (!is.na(result$diff_info$Diff_Value1)) { + expect_equal(result$diff_info$Diff_Value1, df$Target[1]) + } +}) + +test_that("make_stationary preserves non-numeric columns", { + set.seed(123) + df <- tibble::tibble( + Combo = rep("A", 20), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 20), + Target = cumsum(rnorm(20)), + Category = rep("cat1", 20) + ) + + result <- make_stationary(df) + expect_true("Category" %in% colnames(result$data)) +}) + +test_that("make_stationary with Target_Original uses same ndiffs as Target", { + set.seed(42) + df <- tibble::tibble( + Combo = rep("A", 50), + Date = seq.Date(as.Date("2016-01-01"), by = "month", length.out = 50), + Target = cumsum(rnorm(50)), + Target_Original = cumsum(rnorm(50)) + ) + + result <- make_stationary(df) + expect_true("Target_Original" %in% colnames(result$data)) +}) + +# -- clean_outliers_missing_values additional tests -- + +test_that("clean_outliers_missing_values handles column with <2 non-NA values", { + df <- tibble::tibble( + Combo = rep("A", 10), + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Target = c(10, 20, 30, 40, 50, 60, 70, 80, 90, 100), + Sparse_Xreg = c(5, rep(NA, 9)) # only 1 non-NA value + ) + + result <- clean_outliers_missing_values( + df, + clean_outliers = FALSE, + clean_missing_values = TRUE, + frequency_number = 12, + external_regressors = "Sparse_Xreg" + ) + + expect_s3_class(result, "tbl_df") +}) + +# -- align_types tests -- + +test_that("align_types converts integer to match numeric", { + df1 <- data.frame(x = 1.5, y = "a", stringsAsFactors = FALSE) + df2 <- data.frame(x = 1L, y = "b", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_type(result$x, "double") + expect_equal(result$x, 1.0) +}) + +test_that("align_types converts numeric to match integer", { + df1 <- data.frame(x = 1L, y = "a", stringsAsFactors = FALSE) + df2 <- data.frame(x = 2.0, y = "b", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_type(result$x, "integer") +}) + +test_that("align_types converts character to match factor", { + df1 <- data.frame(x = factor("a")) + df2 <- data.frame(x = "a", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_s3_class(result$x, "factor") +}) + +test_that("align_types converts to match Date", { + df1 <- data.frame(x = as.Date("2020-01-01")) + df2 <- data.frame(x = "2020-01-01", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_s3_class(result$x, "Date") + expect_equal(result$x, as.Date("2020-01-01")) +}) + +test_that("align_types converts to match logical", { + df1 <- data.frame(x = TRUE) + df2 <- data.frame(x = 1L) + result <- align_types(df1, df2) + expect_type(result$x, "logical") + expect_true(result$x) +}) + +test_that("align_types converts to match character", { + df1 <- data.frame(x = "abc", stringsAsFactors = FALSE) + df2 <- data.frame(x = 123) + result <- align_types(df1, df2) + expect_type(result$x, "character") + expect_equal(result$x, "123") +}) + +test_that("align_types only affects shared columns", { + df1 <- data.frame(x = 1.0, stringsAsFactors = FALSE) + df2 <- data.frame(x = 1L, extra = "keep", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_type(result$x, "double") + expect_equal(result$extra, "keep") +}) + +test_that("align_types handles POSIXct conversion", { + df1 <- data.frame(x = as.POSIXct("2020-01-01 12:00:00", tz = "UTC")) + df2 <- data.frame(x = "2020-01-01 12:00:00", stringsAsFactors = FALSE) + result <- align_types(df1, df2) + expect_s3_class(result$x, "POSIXct") +}) + +# -- get_xregs_future_values_tbl tests -- + +test_that("get_xregs_future_values_tbl returns matching columns", { + data_tbl <- tibble::tibble( + Combo = rep("A", 10), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 10), + xreg1 = 1:10, + xreg2 = c(rep(NA, 5), 6:10) + ) + hist_end_date <- as.Date("2020-05-01") + result <- get_xregs_future_values_tbl( + data_tbl, c("xreg1", "xreg2"), hist_end_date + ) + expect_true("Combo" %in% names(result)) + expect_true("Date" %in% names(result)) + expect_true("xreg1" %in% names(result)) + expect_true("xreg2" %in% names(result)) +}) + +test_that("get_xregs_future_values_tbl excludes all-NA future regressors", { + data_tbl <- tibble::tibble( + Combo = rep("A", 10), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 10), + xreg1 = 1:10, + xreg2 = c(1:5, rep(NA, 5)) + ) + hist_end_date <- as.Date("2020-05-01") + result <- get_xregs_future_values_tbl( + data_tbl, c("xreg1", "xreg2"), hist_end_date + ) + expect_true("xreg1" %in% names(result)) + expect_false("xreg2" %in% names(result)) +}) + +test_that("get_xregs_future_values_tbl with no external regressors", { + data_tbl <- tibble::tibble( + Combo = rep("A", 5), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 5) + ) + result <- get_xregs_future_values_tbl( + data_tbl, character(0), as.Date("2020-03-01") + ) + expect_equal(names(result), c("Combo", "Date")) +}) + +# -- Additional clean_outliers_missing_values tests -- + +test_that("clean_outliers_missing_values with no cleaning", { + df <- tibble::tibble( + Target = c(1, 2, 3, NA, 5, 6, 7, 8, 9, 10), + xreg1 = 1:10 + ) + result <- clean_outliers_missing_values( + df, + clean_outliers = FALSE, + clean_missing_values = FALSE, + frequency_number = 1, + external_regressors = "xreg1" + ) + expect_equal(result$Target, df$Target) +}) + +test_that("clean_outliers_missing_values imputes missing values", { + df <- tibble::tibble( + Target = c(1, 2, NA, 4, 5, 6, 7, 8, 9, 10), + xreg1 = 1:10 + ) + result <- clean_outliers_missing_values( + df, + clean_outliers = FALSE, + clean_missing_values = TRUE, + frequency_number = 1, + external_regressors = "xreg1" + ) + expect_false(any(is.na(result$Target))) +}) + +test_that("clean_outliers_missing_values cleans outliers", { + set.seed(123) + df <- tibble::tibble( + Target = c(rep(5, 20), 500, rep(5, 9)), + xreg1 = 1:30 + ) + result <- clean_outliers_missing_values( + df, + clean_outliers = TRUE, + clean_missing_values = FALSE, + frequency_number = 1, + external_regressors = "xreg1" + ) + expect_true("Target_Original" %in% names(result)) + # outlier should be cleaned + expect_true(result$Target[21] < 500) +}) + +# -- Additional make_stationary test -- + +test_that("make_stationary handles binary numeric columns", { + df <- tibble::tibble( + Combo = rep("A", 30), + Date = seq(as.Date("2020-01-01"), by = "month", length.out = 30), + Target = rnorm(30), + Binary = rep(c(0, 1), 15) + ) + result <- make_stationary(df) + # binary column should not be differenced (only 2 unique values) + expect_equal(result$data$Binary, df$Binary) +}) diff --git a/tests/testthat/test-prep_models_helpers.R b/tests/testthat/test-prep_models_helpers.R new file mode 100644 index 00000000..7cf3adb7 --- /dev/null +++ b/tests/testthat/test-prep_models_helpers.R @@ -0,0 +1,139 @@ +# tests/testthat/test-prep_models_helpers.R +# Tests for helper functions in prep_models.R + +test_that("get_back_test_spacing returns custom when provided", { + result <- get_back_test_spacing(3, "month") + expect_equal(result, 3) +}) + +test_that("get_back_test_spacing returns 1 for monthly data", { + result <- get_back_test_spacing(NULL, "month") + expect_equal(result, 1) +}) + +test_that("get_back_test_spacing returns 1 for yearly data", { + result <- get_back_test_spacing(NULL, "year") + expect_equal(result, 1) +}) + +test_that("get_back_test_spacing returns 1 for quarterly data", { + result <- get_back_test_spacing(NULL, "quarter") + expect_equal(result, 1) +}) + +test_that("get_back_test_spacing returns 4 for weekly data", { + result <- get_back_test_spacing(NULL, "week") + expect_equal(result, 4) +}) + +test_that("get_back_test_spacing returns 7 for daily data", { + result <- get_back_test_spacing(NULL, "day") + expect_equal(result, 7) +}) + +test_that("get_frequency_number prep_models returns correct values", { + # prep_models has its own version of get_frequency_number + # testing through prep_models namespace + expect_equal(get_frequency_number("year"), 1) + expect_equal(get_frequency_number("quarter"), 4) + expect_equal(get_frequency_number("month"), 12) +}) + +test_that("get_date_type returns correct date type", { + expect_equal(get_date_type(1), "year") + expect_equal(get_date_type(4), "quarter") + expect_equal(get_date_type(12), "month") + expect_equal(get_date_type(52.17857), "week") + expect_equal(get_date_type(365.25), "day") +}) + +test_that("get_seasonal_periods returns correct values", { + result_year <- get_seasonal_periods("year") + expect_equal(result_year, c(1, 2, 3)) + + result_month <- get_seasonal_periods("month") + expect_equal(result_month, c(12, 6, 3)) + + result_quarter <- get_seasonal_periods("quarter") + expect_equal(result_quarter, c(4, 2, 8)) +}) + +test_that("get_back_test_scenario_hist_periods computes correctly", { + input_tbl <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 50), + Combo = "A", + Target = rnorm(50) + ) + + result <- get_back_test_scenario_hist_periods( + input_tbl, + hist_end_date = as.Date("2024-02-01"), + forecast_horizon = 3, + back_test_scenarios = NULL, + back_test_spacing = 1 + ) + + expect_true("hist_periods_80" %in% names(result)) + expect_true("back_test_scenarios" %in% names(result)) + expect_true(result$hist_periods_80 > 0) + expect_true(result$back_test_scenarios > 0) +}) + +test_that("get_back_test_scenario_hist_periods uses custom scenarios", { + input_tbl <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 50), + Combo = "A", + Target = rnorm(50) + ) + + result <- get_back_test_scenario_hist_periods( + input_tbl, + hist_end_date = as.Date("2024-02-01"), + forecast_horizon = 3, + back_test_scenarios = 5, + back_test_spacing = 1 + ) + + expect_equal(result$back_test_scenarios, 6) # scenarios + 1 +}) + +# -- get_frequency_number additional tests -- + +test_that("get_frequency_number returns correct values for week and day", { + expect_equal(get_frequency_number("week"), 52.17857) + expect_equal(get_frequency_number("day"), 365.25) +}) + +# -- get_seasonal_periods additional tests -- + +test_that("get_seasonal_periods returns correct values for week", { + result <- get_seasonal_periods("week") + expect_equal(result, c(365.25 / 7, (365.25 / 7) / 4, (365.25 / 7) / 12)) +}) + +test_that("get_seasonal_periods returns correct values for day", { + result <- get_seasonal_periods("day") + expect_equal(result, c(365.25, 365.25 / 4, 365.25 / 12)) +}) + +# -- get_back_test_scenario_hist_periods edge cases -- + +test_that("get_back_test_scenario_hist_periods with short data", { + input_tbl <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 8), + Combo = "A", + Target = rnorm(8) + ) + + result <- get_back_test_scenario_hist_periods( + input_tbl, + hist_end_date = as.Date("2020-08-01"), + forecast_horizon = 3, + back_test_scenarios = NULL, + back_test_spacing = 1 + ) + + expect_type(result, "list") + expect_true("back_test_scenarios" %in% names(result)) + expect_true(result$back_test_scenarios >= 1) +}) diff --git a/tests/testthat/test-project_info.R b/tests/testthat/test-project_info.R new file mode 100644 index 00000000..6e2e1e9c --- /dev/null +++ b/tests/testthat/test-project_info.R @@ -0,0 +1,257 @@ +# tests/testthat/test-project_info.R + +test_that("set_project_info creates project info with defaults", { + project_info <- set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + expect_type(project_info, "list") + expect_true("project_name" %in% names(project_info)) + expect_true("combo_variables" %in% names(project_info)) + expect_true("target_variable" %in% names(project_info)) + expect_true("date_type" %in% names(project_info)) + expect_true("path" %in% names(project_info)) + expect_true("created" %in% names(project_info)) + expect_true("data_output" %in% names(project_info)) + expect_true("object_output" %in% names(project_info)) + expect_equal(project_info$project_name, "finn_project") + expect_equal(project_info$data_output, "csv") + expect_equal(project_info$object_output, "rds") + expect_equal(project_info$combo_variables, c("id")) + expect_equal(project_info$target_variable, "value") + expect_equal(project_info$date_type, "month") +}) + +test_that("set_project_info creates project info with custom values", { + project_info <- set_project_info( + project_name = "my_project", + combo_variables = c("Store", "Product"), + target_variable = "Sales", + date_type = "week", + fiscal_year_start = 7, + weekly_to_daily = FALSE, + data_output = "parquet", + object_output = "qs2" + ) + + expect_equal(project_info$project_name, "my_project") + expect_equal(project_info$combo_variables, c("Store", "Product")) + expect_equal(project_info$target_variable, "Sales") + expect_equal(project_info$date_type, "week") + expect_equal(project_info$fiscal_year_start, 7) + expect_equal(project_info$weekly_to_daily, FALSE) + expect_equal(project_info$data_output, "parquet") + expect_equal(project_info$object_output, "qs2") +}) + +test_that("set_project_info errors on invalid data_output", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month", + data_output = "json" + ), + "invalid value for input name 'data_output'" + ) +}) + +test_that("set_project_info errors on invalid object_output", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month", + object_output = "pickle" + ), + "invalid value for input name 'object_output'" + ) +}) + +test_that("set_project_info errors on invalid project_name type", { + expect_error( + set_project_info( + project_name = 123, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ), + "`project_name` must either be a NULL or a string" + ) +}) + +test_that("set_project_info errors on invalid date_type", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "hourly" + ), + "invalid value for input name 'date_type'" + ) +}) + +test_that("set_project_info accepts all valid date types", { + for (dt in c("year", "quarter", "month", "week", "day")) { + project_info <- set_project_info( + project_name = paste0("date_type_test_", dt), + combo_variables = c("id"), + target_variable = "value", + date_type = dt + ) + expect_equal(project_info$date_type, dt) + } +}) + +test_that("set_project_info errors on invalid combo_variables type", { + expect_error( + set_project_info( + combo_variables = 123, + target_variable = "value", + date_type = "month" + ), + "invalid type for input name 'combo_variables'" + ) +}) + +test_that("set_project_info errors on invalid target_variable type", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = 123, + date_type = "month" + ), + "invalid type for input name 'target_variable'" + ) +}) + +test_that("set_project_info errors on invalid weekly_to_daily type", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month", + weekly_to_daily = "yes" + ), + "invalid type for input name 'weekly_to_daily'" + ) +}) + +test_that("set_project_info errors on invalid path type", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month", + path = 123 + ), + "`path` must either be a NULL or a string" + ) +}) + +test_that("set_project_info errors on invalid storage_object type", { + expect_error( + set_project_info( + combo_variables = c("id"), + target_variable = "value", + date_type = "month", + storage_object = "invalid" + ), + "`storage_object` must either be a NULL" + ) +}) + +test_that("set_project_info with explicit path creates directories", { + temp_dir <- tempfile("project_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + project_info <- set_project_info( + path = temp_dir, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + expect_true(dir.exists(file.path(temp_dir, "eda"))) + expect_true(dir.exists(file.path(temp_dir, "input_data"))) + expect_true(dir.exists(file.path(temp_dir, "logs"))) + expect_true(dir.exists(file.path(temp_dir, "final_output"))) + expect_true(dir.exists(file.path(temp_dir, "prep_data"))) + expect_true(dir.exists(file.path(temp_dir, "prep_models"))) + expect_true(dir.exists(file.path(temp_dir, "models"))) + expect_true(dir.exists(file.path(temp_dir, "forecasts"))) +}) + +test_that("set_project_info with overwrite re-creates project", { + temp_dir <- tempfile("project_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + project_info1 <- set_project_info( + project_name = "overwrite_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + # second call with same inputs should use existing + project_info2 <- set_project_info( + project_name = "overwrite_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + expect_equal(project_info2$project_name, "overwrite_test") +}) + +test_that("set_project_info errors when inputs change without overwrite", { + temp_dir <- tempfile("project_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + set_project_info( + project_name = "change_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + expect_error( + set_project_info( + project_name = "change_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "revenue", + date_type = "month" + ), + "Inputs have recently changed" + ) +}) + +test_that("set_project_info overwrite allows changing inputs", { + temp_dir <- tempfile("project_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + set_project_info( + project_name = "overwrite_change_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "value", + date_type = "month" + ) + + project_info <- set_project_info( + project_name = "overwrite_change_test", + path = temp_dir, + combo_variables = c("id"), + target_variable = "revenue", + date_type = "month", + overwrite = TRUE + ) + + expect_equal(project_info$target_variable, "revenue") +}) diff --git a/tests/testthat/test-read_write_data.R b/tests/testthat/test-read_write_data.R new file mode 100644 index 00000000..48c35981 --- /dev/null +++ b/tests/testthat/test-read_write_data.R @@ -0,0 +1,444 @@ +# tests/testthat/test-read_write_data.R +# Tests for read_write_data.R functions + +# -- hash_data tests -- + +test_that("hash_data produces consistent hashes", { + hash1 <- hash_data("test_string") + hash2 <- hash_data("test_string") + + expect_equal(hash1, hash2) + expect_type(hash1, "character") + expect_true(nchar(hash1) > 0) +}) + +test_that("hash_data produces different hashes for different inputs", { + hash1 <- hash_data("hello") + hash2 <- hash_data("world") + + expect_false(hash1 == hash2) +}) + +test_that("hash_data works with various types", { + expect_type(hash_data("string"), "character") + expect_type(hash_data(42), "character") + expect_type(hash_data(list(a = 1, b = 2)), "character") + expect_type(hash_data(data.frame(x = 1:3)), "character") +}) + +# -- write_data_type tests -- + +test_that("write_data_type writes csv files", { + temp_file <- tempfile(fileext = ".csv") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + write_data_type(df, temp_file, "csv") + + expect_true(file.exists(temp_file)) + read_back <- vroom::vroom(temp_file, delim = ",", show_col_types = FALSE) + expect_equal(nrow(read_back), 5) +}) + +test_that("write_data_type writes rds files", { + temp_file <- tempfile(fileext = ".rds") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + write_data_type(df, temp_file, "rds") + + expect_true(file.exists(temp_file)) + read_back <- readRDS(temp_file) + expect_equal(nrow(read_back), 5) +}) + +test_that("write_data_type writes single-row csv as log format", { + temp_file <- tempfile(fileext = ".csv") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1, b = "x") + write_data_type(df, temp_file, "csv") + + expect_true(file.exists(temp_file)) + read_back <- utils::read.csv(temp_file) + expect_equal(nrow(read_back), 1) +}) + +# -- write_data tests -- + +test_that("write_data writes to temp directory when path is NULL", { + run_info <- list( + project_name = "test_proj", + run_name = "test_run", + storage_object = NULL, + path = NULL, + data_output = "csv", + object_output = "rds" + ) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + + expect_no_error( + write_data( + x = df, + combo = "test_combo", + run_info = run_info, + output_type = "data", + folder = "forecasts", + suffix = "-test" + ) + ) +}) + +test_that("write_data writes to explicit path", { + temp_dir <- tempfile("write_data_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + run_info <- list( + project_name = "test_proj", + run_name = "test_run", + storage_object = NULL, + path = temp_dir, + data_output = "csv", + object_output = "rds" + ) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + + write_data( + x = df, + combo = "test_combo", + run_info = run_info, + output_type = "data", + folder = "forecasts", + suffix = "-test" + ) + + files <- list.files(file.path(temp_dir, "forecasts"), pattern = "\\.csv$") + expect_true(length(files) > 0) +}) + +test_that("write_data handles NULL combo", { + run_info <- list( + project_name = "test_proj", + run_name = "test_run", + storage_object = NULL, + path = NULL, + data_output = "csv", + object_output = "rds" + ) + + df <- tibble::tibble(a = 1, b = "x") + + expect_no_error( + write_data( + x = df, + combo = NULL, + run_info = run_info, + output_type = "log", + folder = "logs", + suffix = NULL + ) + ) +}) + +test_that("write_data handles object output type", { + temp_dir <- tempfile("write_data_obj_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + run_info <- list( + project_name = "test_proj", + run_name = "test_run", + storage_object = NULL, + path = temp_dir, + data_output = "csv", + object_output = "rds" + ) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + + write_data( + x = df, + combo = "test_combo", + run_info = run_info, + output_type = "object", + folder = "models", + suffix = "-test" + ) + + files <- list.files(file.path(temp_dir, "models"), pattern = "\\.rds$") + expect_true(length(files) > 0) +}) + +# -- custom_ls tests -- + +test_that("custom_ls lists files in local directory", { + temp_dir <- tempdir() + test_dir <- file.path(temp_dir, "test_custom_ls") + dir.create(test_dir, showWarnings = FALSE, recursive = TRUE) + on.exit(unlink(test_dir, recursive = TRUE), add = TRUE) + + file.create(file.path(test_dir, "file1.csv")) + file.create(file.path(test_dir, "file2.csv")) + file.create(file.path(test_dir, "file3.txt")) + + result <- custom_ls(file.path(test_dir, "*.csv")) + expect_equal(length(result), 2) + expect_true(all(grepl("\\.csv$", result))) +}) + +test_that("custom_ls returns empty for non-existent pattern", { + temp_dir <- tempdir() + test_dir <- file.path(temp_dir, "test_custom_ls2") + dir.create(test_dir, showWarnings = FALSE, recursive = TRUE) + on.exit(unlink(test_dir, recursive = TRUE), add = TRUE) + + file.create(file.path(test_dir, "file1.csv")) + + result <- custom_ls(file.path(test_dir, "*.parquet")) + expect_equal(length(result), 0) +}) + +test_that("custom_ls with wildcard glob matches all files", { + temp_dir <- tempdir() + test_dir <- file.path(temp_dir, "test_custom_ls3") + dir.create(test_dir, showWarnings = FALSE, recursive = TRUE) + on.exit(unlink(test_dir, recursive = TRUE), add = TRUE) + + file.create(file.path(test_dir, "a.csv")) + file.create(file.path(test_dir, "b.txt")) + + result <- custom_ls(file.path(test_dir, "*")) + expect_equal(length(result), 2) +}) + +test_that("custom_ls validates input is character", { + expect_error(custom_ls(123)) +}) + +test_that("custom_ls validates input is length 1", { + expect_error(custom_ls(c("a", "b"))) +}) + +# -- list_files tests -- + +test_that("list_files with NULL storage_object lists local files", { + temp_dir <- tempdir() + test_dir <- file.path(temp_dir, "test_list_files") + dir.create(test_dir, showWarnings = FALSE, recursive = TRUE) + on.exit(unlink(test_dir, recursive = TRUE), add = TRUE) + + file.create(file.path(test_dir, "data1.csv")) + file.create(file.path(test_dir, "data2.csv")) + + result <- list_files(NULL, file.path(test_dir, "*.csv")) + expect_equal(length(result), 2) + expect_true(all(grepl("\\.csv$", result))) +}) + +# -- write_data_folder tests -- + +test_that("write_data_folder writes to local with NULL storage", { + temp_file <- tempfile(fileext = ".csv") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(x = 1:5, y = letters[1:5]) + write_data_folder( + x = df, + storage_object = NULL, + final_dest = "NULL", + temp_path = NULL, + final_path = temp_file, + type = "csv" + ) + expect_true(file.exists(temp_file)) +}) + +# -- write_data_type parquet tests -- + +test_that("write_data_type writes parquet files", { + skip_if_not_installed("arrow") + temp_file <- tempfile(fileext = ".parquet") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + write_data_type(df, temp_file, "parquet") + + expect_true(file.exists(temp_file)) + read_back <- arrow::read_parquet(temp_file) + expect_equal(nrow(read_back), 5) +}) + +# -- write_data_type qs2 tests -- + +test_that("write_data_type writes qs2 files", { + skip_if_not_installed("qs2") + temp_file <- tempfile(fileext = ".qs2") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + write_data_type(df, temp_file, "qs2") + + expect_true(file.exists(temp_file)) + read_back <- qs2::qs_read(temp_file) + expect_equal(nrow(read_back), 5) +}) + +# -- read_file tests (local storage) -- + +test_that("read_file reads csv file with file_list parameter", { + temp_file <- tempfile(fileext = ".csv") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(x = 1:5, y = letters[1:5]) + vroom::vroom_write(df, temp_file, delim = ",", progress = FALSE) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + result <- read_file(run_info, file_list = temp_file) + expect_equal(nrow(result), 5) + expect_true("x" %in% colnames(result)) +}) + +test_that("read_file reads rds file with file_list parameter", { + temp_file <- tempfile(fileext = ".rds") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:3, b = c("x", "y", "z")) + saveRDS(df, temp_file) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "rds", + object_output = "rds" + ) + + result <- read_file(run_info, file_list = temp_file) + expect_equal(nrow(result), 3) + expect_true("a" %in% colnames(result)) +}) + +test_that("read_file reads parquet file with file_list parameter", { + skip_if_not_installed("arrow") + temp_file <- tempfile(fileext = ".parquet") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:4, b = letters[1:4]) + arrow::write_parquet(df, temp_file) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "parquet", + object_output = "rds" + ) + + result <- read_file(run_info, file_list = temp_file) + expect_equal(nrow(result), 4) +}) + +test_that("read_file with return_type='object' reads rds", { + temp_file <- tempfile(fileext = ".rds") + on.exit(unlink(temp_file), add = TRUE) + + obj <- list(model = "test", value = 42) + saveRDS(obj, temp_file) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + result <- read_file(run_info, file_list = temp_file, return_type = "object") + expect_type(result, "list") + expect_equal(result$model, "test") +}) + +test_that("read_file with return_type='arrow' opens arrow dataset", { + skip("arrow::open_dataset version incompatibility in test environment") + skip_if_not_installed("arrow") + temp_file <- tempfile(fileext = ".parquet") + on.exit(unlink(temp_file), add = TRUE) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + arrow::write_parquet(df, temp_file) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "parquet", + object_output = "rds" + ) + + result <- read_file(run_info, file_list = temp_file, return_type = "arrow") + expect_s3_class(result, "Dataset") +}) + +test_that("read_file reads multiple csv files", { + temp_dir <- file.path(tempdir(), "multi_csv_test") + dir.create(temp_dir, showWarnings = FALSE, recursive = TRUE) + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + df1 <- tibble::tibble(x = 1:3, y = c("a", "b", "c")) + df2 <- tibble::tibble(x = 4:6, y = c("d", "e", "f")) + vroom::vroom_write(df1, file.path(temp_dir, "file1.csv"), delim = ",", progress = FALSE) + vroom::vroom_write(df2, file.path(temp_dir, "file2.csv"), delim = ",", progress = FALSE) + + run_info <- list( + storage_object = NULL, + path = tempdir(), + data_output = "csv", + object_output = "rds" + ) + + files <- c(file.path(temp_dir, "file1.csv"), file.path(temp_dir, "file2.csv")) + result <- read_file(run_info, file_list = files) + expect_equal(nrow(result), 6) +}) + +# -- write_data with parquet output -- + +test_that("write_data writes parquet to explicit path", { + skip_if_not_installed("arrow") + temp_dir <- tempfile("write_data_parquet_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + run_info <- list( + project_name = "test_proj", + run_name = "test_run", + storage_object = NULL, + path = temp_dir, + data_output = "parquet", + object_output = "rds" + ) + + df <- tibble::tibble(a = 1:5, b = letters[1:5]) + + write_data( + x = df, + combo = "test_combo", + run_info = run_info, + output_type = "data", + folder = "forecasts", + suffix = "-test" + ) + + files <- list.files(file.path(temp_dir, "forecasts"), pattern = "\\.parquet$") + expect_true(length(files) > 0) +}) + +# -- list_files edge cases -- + +test_that("list_files returns exact path when no wildcard", { + result <- list_files(NULL, "/some/path/file.csv") + expect_equal(result, "/some/path/file.csv") +}) + diff --git a/tests/testthat/test-run_info.R b/tests/testthat/test-run_info.R new file mode 100644 index 00000000..487242b8 --- /dev/null +++ b/tests/testthat/test-run_info.R @@ -0,0 +1,216 @@ +# tests/testthat/test-run_info.R + +test_that("set_run_info creates run info with defaults", { + run_info <- set_run_info() + + expect_type(run_info, "list") + expect_true("project_name" %in% names(run_info)) + expect_true("run_name" %in% names(run_info)) + expect_true("created" %in% names(run_info)) + expect_true("path" %in% names(run_info)) + expect_true("data_output" %in% names(run_info)) + expect_true("object_output" %in% names(run_info)) + expect_equal(run_info$project_name, "finn_project") + expect_equal(run_info$data_output, "csv") + expect_equal(run_info$object_output, "rds") +}) + +test_that("set_run_info creates run info with custom values", { + run_info <- set_run_info( + project_name = "my_project", + run_name = "test_run", + data_output = "parquet", + object_output = "qs2" + ) + + expect_equal(run_info$project_name, "my_project") + expect_true(grepl("test_run", run_info$run_name)) + expect_equal(run_info$data_output, "parquet") + expect_equal(run_info$object_output, "qs2") +}) + +test_that("set_run_info appends unique id by default", { + run_info <- set_run_info(run_name = "test_run") + + # run_name should be test_run followed by a timestamp + expect_true(grepl("^test_run-", run_info$run_name)) +}) + +test_that("set_run_info without unique id", { + run_info <- set_run_info( + run_name = "fixed_run", + add_unique_id = FALSE + ) + + expect_equal(run_info$run_name, "fixed_run") +}) + +test_that("set_run_info errors on invalid data_output", { + expect_error( + set_run_info(data_output = "json"), + "invalid value for input name 'data_output'" + ) +}) + +test_that("set_run_info errors on invalid object_output", { + expect_error( + set_run_info(object_output = "pickle"), + "invalid value for input name 'object_output'" + ) +}) + +test_that("set_run_info errors on invalid run_name type", { + expect_error( + set_run_info(run_name = 123), + "`run_name` must either be a NULL or a string" + ) +}) + +test_that("set_run_info errors on invalid path type", { + expect_error( + set_run_info(path = 123), + "`path` must either be a NULL or a string" + ) +}) + +test_that("set_run_info errors on invalid storage_object type", { + expect_error( + set_run_info(storage_object = "not_valid"), + "`storage_object` must either be a NULL" + ) +}) + +test_that("set_run_info with explicit path creates correct directories", { + temp_dir <- tempfile("run_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + run_info <- set_run_info(path = temp_dir) + + expect_true(dir.exists(file.path(temp_dir, "prep_data"))) + expect_true(dir.exists(file.path(temp_dir, "prep_models"))) + expect_true(dir.exists(file.path(temp_dir, "models"))) + expect_true(dir.exists(file.path(temp_dir, "forecasts"))) +}) + +test_that("set_run_info creates storage_object in output list", { + run_info <- set_run_info() + + expect_true("storage_object" %in% names(run_info)) + expect_null(run_info$storage_object) +}) + +test_that("get_run_info errors on invalid run_name type", { + expect_error( + get_run_info(run_name = 123), + "`run_name` must either be a NULL or a string" + ) +}) + +test_that("get_run_info errors on invalid path type", { + expect_error( + get_run_info(path = 123), + "`path` must either be a NULL or a string" + ) +}) + +test_that("get_run_info errors on invalid storage_object type", { + expect_error( + get_run_info(storage_object = "not_valid"), + "`storage_object` must either be a NULL" + ) +}) + +test_that("get_run_info returns run data after set_run_info", { + temp_dir <- tempfile("run_info_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + run_info <- set_run_info( + project_name = "test_proj", + run_name = "my_run", + path = temp_dir, + add_unique_id = FALSE + ) + + result <- get_run_info( + project_name = "test_proj", + run_name = "my_run", + path = temp_dir + ) + + expect_s3_class(result, "data.frame") + expect_true(nrow(result) > 0) + expect_equal(result$project_name, "test_proj") + expect_equal(result$run_name, "my_run") +}) + +test_that("set_run_info with add_unique_id=FALSE reuses existing log", { + temp_dir <- tempfile("run_info_reuse_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + ri1 <- set_run_info( + project_name = "p", + run_name = "r", + path = temp_dir, + add_unique_id = FALSE + ) + + ri2 <- set_run_info( + project_name = "p", + run_name = "r", + path = temp_dir, + add_unique_id = FALSE + ) + + expect_equal(ri1$created, ri2$created) + expect_equal(ri1$path, ri2$path) +}) + +test_that("set_run_info with add_unique_id=FALSE errors on changed inputs", { + temp_dir <- tempfile("run_info_changed_test") + on.exit(unlink(temp_dir, recursive = TRUE), add = TRUE) + + set_run_info( + project_name = "p", + run_name = "r", + path = temp_dir, + add_unique_id = FALSE + ) + + expect_error( + set_run_info( + project_name = "p", + run_name = "r", + path = temp_dir, + data_output = "parquet", + add_unique_id = FALSE + ), + "Inputs have recently changed" + ) +}) + +# -- utility.R tests -- + +test_that("get_timestamp returns a POSIXct object in UTC", { + ts <- get_timestamp() + + expect_s3_class(ts, "POSIXct") + expect_equal(attr(ts, "tzone"), "UTC") +}) + +test_that("get_timestamp returns a timestamp close to current time", { + ts <- get_timestamp() + now_utc <- as.POSIXct(format(Sys.time(), tz = "UTC"), tz = "UTC") + + # should be within 5 seconds of now + + expect_true(abs(difftime(ts, now_utc, units = "secs")) < 5) +}) + +test_that("get_timestamp format is YYYYMMDDTHHMMSSZ", { + ts <- get_timestamp() + formatted <- format(ts, "%Y%m%dT%H%M%SZ") + + # should be parseable back + parsed <- as.POSIXct(formatted, format = "%Y%m%dT%H%M%SZ", tz = "UTC") + expect_false(is.na(parsed)) +}) diff --git a/tests/testthat/test-timegpt.R b/tests/testthat/test-timegpt.R index 95c75619..1e159839 100644 --- a/tests/testthat/test-timegpt.R +++ b/tests/testthat/test-timegpt.R @@ -24,7 +24,7 @@ test_that("TimeGPT API key validation", { api_key <- Sys.getenv("NIXTLA_API_KEY") # Normalize URL - azure_url <- finnts:::normalize_url(azure_url) + azure_url <- normalize_url(azure_url) Sys.setenv(NIXTLA_BASE_URL = azure_url) # Setup client @@ -604,7 +604,7 @@ test_that("pad_time_series_data preserves original y and external regressor valu original_date_range <- range(original_dates) # Pad the data (monthly data, min_size = 48) - padded_df <- finnts:::pad_time_series_data(train_df, date_type = "month", min_size = 48) + padded_df <- pad_time_series_data(train_df, date_type = "month", min_size = 48) # Verify padding occurred (should have >= 48 rows) expect_true(nrow(padded_df) >= 48, @@ -653,7 +653,7 @@ test_that("pad_time_series_data preserves original y and external regressor valu temperature_original = rnorm(60, mean = 20, sd = 2) ) - not_padded_df <- finnts:::pad_time_series_data(large_df, date_type = "month", min_size = 48) + not_padded_df <- pad_time_series_data(large_df, date_type = "month", min_size = 48) expect_equal(nrow(not_padded_df), 60, info = "Data above minimum size should not be padded" @@ -680,7 +680,7 @@ test_that("pad_time_series_data works with multiple combos", { original_m2 <- train_df %>% dplyr::filter(Combo == "M2") # Pad the data - padded_df <- finnts:::pad_time_series_data(train_df, date_type = "month", min_size = 48) + padded_df <- pad_time_series_data(train_df, date_type = "month", min_size = 48) # Verify both combos were padded combo_counts <- padded_df %>% @@ -892,11 +892,11 @@ test_that("TimeGPT uses long-horizon model for monthly forecasts > 24 months", { skip_if_not(has_timegpt_credentials(), "NIXTLA credentials not set") # First verify the helper function detects long horizon correctly - expect_true(finnts:::is_long_horizon_forecast(25, "month"), + expect_true(is_long_horizon_forecast(25, "month"), info = "Helper function: 25 months should be detected as long horizon (> 24)" ) - expect_false(finnts:::is_long_horizon_forecast(12, "month"), + expect_false(is_long_horizon_forecast(12, "month"), info = "Helper function: 12 months should NOT be detected as long horizon (< 24)" ) diff --git a/tests/testthat/test-timegpt_helpers.R b/tests/testthat/test-timegpt_helpers.R new file mode 100644 index 00000000..1370b475 --- /dev/null +++ b/tests/testthat/test-timegpt_helpers.R @@ -0,0 +1,266 @@ +# tests/testthat/test-timegpt_helpers.R +# Tests for pure helper functions in timegpt_model.R that don't require API credentials + +# -- is_azure_url tests -- + +test_that("is_azure_url returns TRUE for Azure URLs", { + + expect_true(is_azure_url("https://my-resource.azure.com/v1")) + expect_true(is_azure_url("https://azure.example.com/api")) + expect_true(is_azure_url("https://my.azure-model.com/")) + expect_true(is_azure_url("http://test.azure.net/endpoint")) +}) + +test_that("is_azure_url returns FALSE for non-Azure URLs", { + expect_false(is_azure_url("https://api.nixtla.io/v1")) + expect_false(is_azure_url("https://example.com/api")) + expect_false(is_azure_url("https://my-endpoint.openai.com")) +}) + +test_that("is_azure_url returns FALSE for NULL or empty string", { + expect_false(is_azure_url(NULL)) + expect_false(is_azure_url("")) +}) + +test_that("is_azure_url is case-insensitive", { + expect_true(is_azure_url("https://my.AZURE.com/api")) + expect_true(is_azure_url("https://my.Azure.COM/api")) +}) + +# -- normalize_url tests -- + +test_that("normalize_url adds trailing slash if missing", { + expect_warning( + result <- normalize_url("https://api.example.com"), + "did not end with" + ) + expect_equal(result, "https://api.example.com/") +}) + +test_that("normalize_url keeps trailing slash if present", { + result <- normalize_url("https://api.example.com/") + expect_equal(result, "https://api.example.com/") +}) + +test_that("normalize_url warns when appending slash", { + expect_warning(normalize_url("https://api.example.com"), "NIXTLA_BASE_URL") +}) + +# -- get_timegpt_min_size tests -- + +test_that("get_timegpt_min_size returns correct size for each date_type", { + expect_equal(get_timegpt_min_size("day"), 300) + expect_equal(get_timegpt_min_size("week"), 64) + expect_equal(get_timegpt_min_size("month"), 48) + expect_equal(get_timegpt_min_size("quarter"), 48) + expect_equal(get_timegpt_min_size("year"), 48) +}) + +test_that("get_timegpt_min_size returns 48 for NULL", { + expect_equal(get_timegpt_min_size(NULL), 48) +}) + +test_that("get_timegpt_min_size returns 48 for unknown date_type", { + expect_equal(get_timegpt_min_size("unknown"), 48) +}) + +# -- is_long_horizon_forecast tests -- + +test_that("is_long_horizon_forecast returns TRUE for long horizons", { + expect_true(is_long_horizon_forecast(15, "day")) + expect_true(is_long_horizon_forecast(105, "week")) + expect_true(is_long_horizon_forecast(25, "month")) + expect_true(is_long_horizon_forecast(9, "quarter")) + expect_true(is_long_horizon_forecast(3, "year")) +}) + +test_that("is_long_horizon_forecast returns FALSE for short horizons", { + expect_false(is_long_horizon_forecast(14, "day")) + expect_false(is_long_horizon_forecast(104, "week")) + expect_false(is_long_horizon_forecast(24, "month")) + expect_false(is_long_horizon_forecast(8, "quarter")) + expect_false(is_long_horizon_forecast(2, "year")) +}) + +test_that("is_long_horizon_forecast returns FALSE for NULL inputs", { + expect_false(is_long_horizon_forecast(NULL, "month")) + expect_false(is_long_horizon_forecast(10, NULL)) + expect_false(is_long_horizon_forecast(NULL, NULL)) +}) + +test_that("is_long_horizon_forecast uses default threshold for unknown date_type", { + expect_true(is_long_horizon_forecast(25, "unknown")) + expect_false(is_long_horizon_forecast(24, "unknown")) +}) + +# -- pad_time_series_data tests -- + +test_that("pad_time_series_data returns unchanged when min_size is NULL", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Combo = rep("A", 10), + y = rnorm(10) + ) + result <- pad_time_series_data(df, "month", min_size = NULL) + expect_equal(nrow(result), 10) +}) + +test_that("pad_time_series_data returns unchanged when date_type is NULL", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 10), + Combo = rep("A", 10), + y = rnorm(10) + ) + result <- pad_time_series_data(df, NULL, min_size = 48) + expect_equal(nrow(result), 10) +}) + +test_that("pad_time_series_data returns unchanged when data already meets min_size", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 50), + Combo = rep("A", 50), + y = rnorm(50) + ) + result <- pad_time_series_data(df, "month", min_size = 48) + expect_equal(nrow(result), 50) +}) + +test_that("pad_time_series_data pads monthly data to min_size", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 20), + Combo = rep("A", 20), + y = rnorm(20) + ) + result <- pad_time_series_data(df, "month", min_size = 30) + expect_true(nrow(result) >= 30) + expect_true(all(result$Combo == "A")) +}) + +test_that("pad_time_series_data fills padded rows with zeros", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2022-01-01"), by = "month", length.out = 5), + Combo = rep("A", 5), + y = rep(100, 5), + xreg = rep(50, 5) + ) + result <- pad_time_series_data(df, "month", min_size = 10) + # padded rows should have y = 0 + padded_rows <- result %>% dplyr::filter(Date < as.Date("2022-01-01")) + expect_true(all(padded_rows$y == 0)) + expect_true(all(padded_rows$xreg == 0)) +}) + +test_that("pad_time_series_data handles daily data", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2023-01-01"), by = "day", length.out = 100), + Combo = rep("A", 100), + y = rnorm(100) + ) + result <- pad_time_series_data(df, "day", min_size = 300) + expect_true(nrow(result) >= 300) +}) + +test_that("pad_time_series_data handles weekly data", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2023-01-01"), by = "week", length.out = 30), + Combo = rep("A", 30), + y = rnorm(30) + ) + result <- pad_time_series_data(df, "week", min_size = 64) + expect_true(nrow(result) >= 64) +}) + +test_that("pad_time_series_data handles multiple combos", { + df <- rbind( + tibble::tibble( + Date = seq.Date(as.Date("2022-01-01"), by = "month", length.out = 48), + Combo = rep("A", 48), + y = rnorm(48) + ), + tibble::tibble( + Date = seq.Date(as.Date("2023-01-01"), by = "month", length.out = 10), + Combo = rep("B", 10), + y = rnorm(10) + ) + ) + result <- pad_time_series_data(df, "month", min_size = 48) + # Combo A already has 48 rows so no padding, Combo B should be padded + a_rows <- result %>% dplyr::filter(Combo == "A") + b_rows <- result %>% dplyr::filter(Combo == "B") + expect_equal(nrow(a_rows), 48) + expect_true(nrow(b_rows) >= 48) +}) + +test_that("pad_time_series_data errors on invalid date_type", { + df <- tibble::tibble( + Date = seq.Date(as.Date("2020-01-01"), by = "month", length.out = 5), + Combo = rep("A", 5), + y = rnorm(5) + ) + expect_error(pad_time_series_data(df, "invalid_type", min_size = 10), "Unsupported date_type") +}) + +# -- finetune_steps tests -- + +test_that("finetune_steps creates a dials parameter", { + param <- finetune_steps() + expect_s3_class(param, "quant_param") + expect_equal(param$type, "integer") +}) + +test_that("finetune_steps uses default range", { + param <- finetune_steps() + expect_equal(param$range$lower, 0L) + expect_equal(param$range$upper, 200L) +}) + +test_that("finetune_steps accepts custom range", { + param <- finetune_steps(range = c(10L, 50L)) + expect_equal(param$range$lower, 10L) + expect_equal(param$range$upper, 50L) +}) + +# -- finetune_depth tests -- + +test_that("finetune_depth creates a dials parameter", { + param <- finetune_depth() + expect_s3_class(param, "quant_param") + expect_equal(param$type, "integer") +}) + +test_that("finetune_depth uses default range", { + param <- finetune_depth() + expect_equal(param$range$lower, 1L) + expect_equal(param$range$upper, 5L) +}) + +test_that("finetune_depth accepts custom range", { + param <- finetune_depth(range = c(2L, 10L)) + expect_equal(param$range$lower, 2L) + expect_equal(param$range$upper, 10L) +}) + +# -- timegpt_model spec tests -- + +test_that("timegpt_model creates model spec", { + model <- timegpt_model(forecast_horizon = 6) + expect_s3_class(model, "timegpt_model") + expect_equal(model$mode, "regression") +}) + +test_that("timegpt_model prints without error", { + model <- timegpt_model(forecast_horizon = 6) + expect_output(print(model), "Main Arguments") +}) + +test_that("timegpt_model update works", { + model <- timegpt_model(forecast_horizon = 6) + updated <- update(model, forecast_horizon = 12) + expect_s3_class(updated, "timegpt_model") +}) + +test_that("timegpt_model update with fresh=TRUE replaces args", { + model <- timegpt_model(forecast_horizon = 6, frequency = "M") + updated <- update(model, forecast_horizon = 3, fresh = TRUE) + expect_s3_class(updated, "timegpt_model") +}) diff --git a/tests/testthat/test-utility.R b/tests/testthat/test-utility.R new file mode 100644 index 00000000..b8d78be7 --- /dev/null +++ b/tests/testthat/test-utility.R @@ -0,0 +1,25 @@ +# tests/testthat/test-utility.R + +test_that("get_timestamp returns a POSIXct object in UTC", { + ts <- get_timestamp() + + expect_s3_class(ts, "POSIXct") + expect_equal(attr(ts, "tzone"), "UTC") +}) + +test_that("get_timestamp returns a timestamp close to current time", { + ts <- get_timestamp() + now_utc <- as.POSIXct(format(Sys.time(), tz = "UTC"), tz = "UTC") + + # should be within 60 seconds of now to avoid flaky failures on slow CI + expect_true(abs(difftime(ts, now_utc, units = "secs")) < 60) +}) + +test_that("get_timestamp format is YYYYMMDDTHHMMSSZ", { + ts <- get_timestamp() + formatted <- format(ts, "%Y%m%dT%H%M%SZ") + + # should be parseable back + parsed <- as.POSIXct(formatted, format = "%Y%m%dT%H%M%SZ", tz = "UTC") + expect_false(is.na(parsed)) +})