diff --git a/DESCRIPTION b/DESCRIPTION
index 904ee4f..9abb785 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: SDModels
Title: Spectrally Deconfounded Models
-Version: 2.0.1
+Version: 2.0.2
Authors@R: c(
person("Markus", "Ulmer", email = "markus.ulmer@stat.math.ethz.ch",
role = c("aut", "cre", "cph"), comment = c(ORCID = "0000-0001-7783-8475")),
@@ -16,12 +16,13 @@ Imports:
ggraph,
gridExtra,
parallel,
- pbapply,
Rdpack,
tidyr,
fda,
grplasso,
- rlang
+ rlang,
+ progressr,
+ parallelly
Suggests:
plotly,
datasets,
@@ -31,6 +32,7 @@ Suggests:
ranger,
HDclassif,
qpdf,
+ cli,
testthat (>= 3.0.0)
RdMacros: Rdpack
Encoding: UTF-8
diff --git a/NEWS.md b/NEWS.md
index 03671ab..00376d5 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,9 @@
+# SDModels 2.0.2
+
+* Switch all the parallelization to futures. See `vignette("Runtime")`
+* Switch all the progress updates to progressr. Progress updates are now also available for parallel processing and are customizable.
+* Process are much more RAM efficient now.
+
# SDModels 2.0.1
* Fix bug in SDTree and SDForest where an error occurred, if X had columns with only one unique value.
diff --git a/R/SDAM.R b/R/SDAM.R
index 527b80a..4195e90 100644
--- a/R/SDAM.R
+++ b/R/SDAM.R
@@ -36,9 +36,11 @@
#' Default is \code{TRUE} to not reduce the signal of high variance covariates.
#' @param ind_lin A vector of indices specifying which covariates to model linearly (i.e. not expanded into basis function).
#' Default is `NULL`.
-#' @param mc.cores Number of cores to use for parallel processing, if \code{mc.cores > 1}
-#' the cross validation is parallelized. Default is `1`. (only supported for unix)
-#' @param verbose If \code{TRUE} fitting information is shown.
+#' @param mc.cores Number of cores to use for parallel computation `vignette("Runtime")`.
+#' The `future` package is used for parallel processing.
+#' To use custom processing plans mc.cores has to be <= 1, see [`future` package](https://future.futureverse.org/).
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
#' @param notRegularized A vector of indices specifying which covariates not to regularize.
#' Default is `NULL`.
#' @return An object of class `SDAM` containing the following elements:
@@ -98,7 +100,8 @@
#' # predict
#' predict(model, newdata = wine[42, ])
#'
-#' ## alternative function call
+#' ## alternative function call with customized progress bar
+#' progressr::handlers(progressr::handler_txtprogressbar(char = cli::col_red(cli::symbol$heart)))
#' mod_none <- SDAM(x = as.matrix(wine[1:10, -c(1, 2)]), y = wine$alcohol[1:10],
#' Q_type = "no_deconfounding", nfolds = 2, n_K = 4,
#' n_lambda1 = 4, n_lambda2 = 8)
@@ -156,8 +159,15 @@ SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
n_unique_X <- apply(X, 2, function(x){length(unique(x))})
# Generate the design and model parameters for every K in vK
- lmodK <- list()
- for (i in 1:length(vK)){
+ progressr::with_progress({
+ pr <- progressr::progressor(along = 1:(n_K), enable = verbose)
+ pr(sprintf("Design generation"), amount = 0, class = "sticky")
+ if(mc.cores > 1){
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = min(mc.cores, n_K)), local = TRUE)
+ }
+
+ lmodK <- future.apply::future_lapply(future.seed = TRUE, 1:length(vK), function(i){
K <- vK[i]
# effective number of basis functions for each Xj, j = 1,..., p
# K_eff[j] can be at most equal to the number of unique values of Xj
@@ -213,9 +223,11 @@ SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
lambda <- rep(0, n_lambda1)
index <- rep(1, length(index))
}
- lmodK[[i]] <- list(Rlist = Rlist, lbreaks = lbreaks, index = index, B = B,
- QB = QB, lambda = lambda, K = K, K_eff = K_eff)
- }
+ pr()
+ list(Rlist = Rlist, lbreaks = lbreaks, index = index,
+ QB = QB, lambda = lambda, K = K, K_eff = K_eff)
+ })
+ })
# generate folds for CV
ind <- sample(rep(1:nfolds, length.out = n), replace = FALSE)
@@ -236,20 +248,34 @@ SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
QYpred <- predict(mod, newdata = listK$QB[test, ])
mse <- apply(QYpred, 2, function(y){mean((y - QY[test])^2)})
+ pr()
return(mse)
}
mse_fold <- function(l){
- MSEl <- lapply(lmodK, function(listK){mse_fold_K(l, listK)})
+ MSEl <- future.apply::future_lapply(future.seed = TRUE, lmodK,
+ mse_fold_K,
+ l = l)
return(unname(do.call(rbind, MSEl)))
}
- if(verbose) print("Initial cross-validation")
- if(mc.cores == 1){
- MSES <- pbapply::pblapply(1:nfolds, mse_fold)
- } else {
- MSES <- parallel::mclapply(1:nfolds, mse_fold, mc.cores = mc.cores)
- }
+ #use random generator that works with multiprocessing
+ ok <- RNGkind("L'Ecuyer-CMRG")
+ progressr::with_progress({
+ pr <- progressr::progressor(along = 1:(nfolds * n_K), enable = verbose)
+ pr(sprintf("Initial cross-validation"), amount = 0, class = "sticky")
+ if(mc.cores > 1){
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = min(mc.cores, nfolds)), local = TRUE)
+ }
+ MSES <- lapply(X = 1:nfolds, mse_fold)
+ })
+
+ #if(mc.cores == 1){
+ # MSES <- pbapply::pblapply(1:nfolds, mse_fold)
+ #} else {
+ # MSES <- parallel::mclapply(1:nfolds, mse_fold, mc.cores = mc.cores)
+ #}
# aggregate MSEs over folds
MSES.agg <- Reduce("+", MSES) / nfolds
@@ -267,13 +293,25 @@ SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
length.out = n_lambda2))
}
- if(verbose) print("Second stage cross-validation")
- if(mc.cores == 1){
- MSES1 <- pbapply::pblapply(1:nfolds, mse_fold_K, listK = modK.min)
- } else {
- MSES1 <- parallel::mclapply(1:nfolds, mse_fold_K, listK = modK.min,
- mc.cores = mc.cores)
- }
+ progressr::with_progress({
+ pr <- progressr::progressor(along = 1:nfolds, enable = verbose)
+ pr(sprintf("Second stage cross-validation"), amount = 0, class = "sticky")
+ if(mc.cores > 1){
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = min(mc.cores, nfolds)), local = TRUE)
+ }
+ MSES1 <- future.apply::future_lapply(future.seed = TRUE,
+ X = 1:nfolds,
+ mse_fold_K,
+ listK = modK.min)
+ })
+ #if(verbose) print("Second stage cross-validation")
+ #if(mc.cores == 1){
+ # MSES1 <- pbapply::pblapply(1:nfolds, mse_fold_K, listK = modK.min)
+ #} else {
+ # MSES1 <- parallel::mclapply(1:nfolds, mse_fold_K, listK = modK.min,
+ # mc.cores = mc.cores)
+ #}
MSES1 <- do.call(rbind, MSES1)
MSE1.agg <- apply(MSES1, 2, mean)
@@ -339,6 +377,7 @@ SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
# estimated active set
lreturn$active <- active
class(lreturn) <- "SDAM"
+ RNGkind(ok[1])
return(lreturn)
}
diff --git a/R/SDForest.R b/R/SDForest.R
index 7580dd6..d7d180b 100644
--- a/R/SDForest.R
+++ b/R/SDForest.R
@@ -34,8 +34,9 @@
#' @param mtry Number of randomly selected covariates to consider for a split,
#' if \code{NULL} half of the covariates are available for each split.
#' \eqn{\text{mtry} = \lfloor \frac{p}{2} \rfloor}
-#' @param mc.cores Number of cores to use for parallel processing,
-#' if \code{mc.cores > 1} the trees are estimated in parallel.
+#' @param mc.cores Number of cores to use for parallel computation `vignette("Runtime")`.
+#' The `future` package is used for parallel processing.
+#' To use custom processing plans mc.cores has to be <= 1, see [`future` package](https://future.futureverse.org/).
#' @param Q_type Type of deconfounding, one of 'trim', 'pca', 'no_deconfounding'.
#' 'trim' corresponds to the Trim transform \insertCite{Cevid2020SpectralModels}{SDModels}
#' as implemented in the Doubly debiased lasso \insertCite{Guo2022DoublyConfounding}{SDModels},
@@ -67,7 +68,8 @@
#' @param Q_scale Should data be scaled to estimate the spectral transformation?
#' Default is \code{TRUE} to not reduce the signal of high variance covariates,
#' and we do not know of a scenario where this hurts.
-#' @param verbose If \code{TRUE} fitting information is shown.
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
#' @param predictors Subset of colnames(X) or numerical indices of the covariates
#' for which an effect on y should be estimated. All the other covariates are only
#' used for deconfounding.
@@ -127,6 +129,8 @@
#' # comparison to classical random forest
#' fit_ranger <- ranger::ranger(Y ~ ., train_data, importance = 'impurity')
#'
+#' # you can customize the progress bar see parameter verbose
+#' progressr::handlers("cli")
#' fit <- SDForest(x = X, y = Y, nTree = 100, Q_type = 'pca', q_hat = 2)
#' fit <- SDForest(Y ~ ., nTree = 100, train_data)
#' fit
@@ -137,6 +141,7 @@
#' plot(fit)
#'
#' # a few more might be helpfull
+#' progressr::handlers(progressr::handler_txtprogressbar(char = cli::col_red(cli::symbol$heart)))
#' fit2 <- SDForest(Y ~ ., nTree = 50, train_data)
#' fit <- mergeForest(fit, fit2)
#'
@@ -276,27 +281,21 @@ SDForest <- function(formula = NULL, data = NULL, x = NULL, y = NULL, nTree = 10
ind <- do.call(c, ind)
}
- #use random generater that works with multiprocessing
+ #use random generator that works with multiprocessing
ok <- RNGkind("L'Ecuyer-CMRG")
-
+
# Worker wrapper for bagged trees
worker_fun <- function(i) {
- Xi <- matrix(X[i, ], ncol = ncol(X))
- colnames(Xi) <- colnames(X)
- if(!is.null(A)){
- Ai <- matrix(A[i, ], ncol = ncol(A))
- }else{
- Ai <- NULL
- }
-
# protect SDTree call
res_i <- tryCatch({
- tree_obj <- SDTree(x = Xi, y = Y[i],
- cp = cp, min_sample = min_sample,
+ tree_obj <- estimate_tree(X = X, Y = Y, Qf = NULL,
+ cp = cp, min_sample = min_sample, max_leaves = n,
Q_type = Q_type, trim_quantile = trim_quantile,
- q_hat = q_hat, mtry = mtry, A = Ai, gamma = gamma,
- max_candidates = max_candidates,
- Q_scale = Q_scale, predictors = predictors)
+ q_hat = q_hat, mtry = mtry, A = A, gamma = gamma,
+ max_candidates = max_candidates, fast = TRUE,
+ Q_scale = Q_scale, predictors = predictors,
+ boot_index = i)
+
list(ok = TRUE, tree = tree_obj)
}, error = function(e) {
list(ok = FALSE, error = conditionMessage(e))
@@ -304,24 +303,21 @@ SDForest <- function(formula = NULL, data = NULL, x = NULL, y = NULL, nTree = 10
# convert warnings to tagged results if needed
list(ok = TRUE, tree = NULL, warning = conditionMessage(w))
})
+ p()
+
res_i
}
+ progressr::with_progress({
+ p <- progressr::progressor(along = ind, enable = verbose)
if(mc.cores > 1){
- if(Sys.info()[["sysname"]] == "Linux"){
- if(verbose) print('mclapply')
- res_list <- parallel::mclapply(ind, worker_fun, mc.cores = mc.cores)
- }else{
- if(verbose) print('future')
- future::plan('multisession', workers = mc.cores)
- res_list <- future.apply::future_lapply(future.seed = TRUE, X = ind, worker_fun)
- }
- }else{
- res_list <- pbapply::pblapply(ind, worker_fun)
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = mc.cores), local = TRUE)
}
- RNGkind(ok[1])
+ res_list <- future.apply::future_lapply(future.seed = TRUE, X = ind, worker_fun)
+ })
- #check worker statuses
+ # check worker statuses
failed_workers <- which(vapply(res_list, function(z) !isTRUE(z$ok), logical(1)))
if (length(failed_workers) > 0) {
stop(sprintf("SDForest: %d worker(s) failed, first error: %s",
@@ -437,6 +433,7 @@ SDForest <- function(formula = NULL, data = NULL, x = NULL, y = NULL, nTree = 10
output$ooEnv_predictions <- ooEnv_predictions
}
+ RNGkind(ok[1])
class(output) <- 'SDForest'
output
}
diff --git a/R/SDTree.R b/R/SDTree.R
index 4d0a54e..6c63fc7 100644
--- a/R/SDTree.R
+++ b/R/SDTree.R
@@ -138,253 +138,8 @@ SDTree <- function(formula = NULL, data = NULL, x = NULL, y = NULL, max_leaves =
if(!is.null(mtry) && mtry < 1) stop('mtry must be larger than 0')
if(n < 2 * min_sample) stop('n must be at least 2 * min_sample')
if(max_candidates < 1) stop('max_candidates must be at least 1')
-
- # estimate spectral transformation
-
- if(!is.null(A)){
- if(is.null(gamma)) stop('gamma must be provided if A is provided')
- if(is.vector(A)) A <- matrix(A)
- if(!is.matrix(A)) stop('A must be a matrix')
- if(nrow(A) != n) stop('A must have n rows')
- Wf <- get_Wf(A, gamma)
- }else {
- Wf <- function(v) v
- }
-
- if(is.null(Qf)){
- if(!is.null(A)){
- Qf <- function(v) get_Qf(Wf(X), Q_type, trim_quantile, q_hat, Q_scale)(Wf(v))
- }else{
- Qf <- get_Qf(X, Q_type, trim_quantile, q_hat, Q_scale)
- }
- }else{
- if(!is.function(Qf)) stop('Q must be a function')
- if(length(Qf(rnorm(n))) == n) stop('Q must map from n to n')
- }
-
- #selection of predictors
- if(!is.null(predictors)){
- if(is.character(predictors)){
- if(!all(predictors %in% colnames(X)))
- stop("predictors must either be numeric columne index or in colnames of X")
- predictors <- which(colnames(X) %in% predictors)
- }
- if(is.numeric(predictors)){
- if(!all(predictors > 0 & predictors <= ncol(X)))
- stop("predictors must either be numeric columne index or in colnames of X")
- }
- pred_names <- colnames(X)
- X <- matrix(X[, predictors], ncol = length(predictors))
- if(!is.null(pred_names)){
- colnames(X) <- pred_names[predictors]
- }
- }
-
- # number of covariates
- p <- ncol(X)
-
- if(!is.null(mtry) && mtry > p) stop('mtry must be at most p')
-
- # calculate first estimate
- E <- matrix(1, n, 1)
- E_tilde <- Qf(E)
- Ue <- E_tilde / sqrt(sum(E_tilde ** 2))
- Y_tilde <- Qf(Y)
-
- # solve linear model
- c_hat <- qr.coef(qr(E_tilde), Y_tilde)
- c_hat <- as.numeric(c_hat)
-
- loss_start <- as.numeric(sum((Y_tilde - c_hat) ** 2) / n)
- loss_temp <- loss_start
-
- # initialize tree
- treeInfo <- c("name", "left", "right", "j", "s", "value", "dloss",
- "res_dloss", "cp", "n_samples", "leaf")
- d <- length(treeInfo)
-
- tree <- matrix(0, ncol = d, nrow = 1, dimnames = list(NULL, treeInfo))
- tree[1, c("name", "value", "dloss", "cp", "n_samples", "leaf")] <-
- c(1, c_hat, loss_start, 10, n, 1)
- treeSize <- 1
-
- # memory for optimal splits
- memory <- list()
- potential_splits <- 1
-
- # variable importance
- var_imp <- rep(0, p)
- names(var_imp) <- colnames(X)
-
- after_mtry <- 0
-
- for(i in 1:max_leaves){
- # iterate over all possible splits every time
- # for slow but slightly better solution
- if(!fast){
- potential_splits <- 1:i
- to_small <- sapply(potential_splits,
- function(x){sum(E[, x]) < min_sample*2})
- potential_splits <- potential_splits[!to_small]
- }
-
- #iterate over new to estimate splits
- for(branch in potential_splits){
- # get samples in branch to evaluate
- E_branch <- E[, branch]
- index <- which(E_branch == 1)
- X_branch <- matrix(X[index, ], nrow = length(index))
-
- # get potential splitting candidates
- s <- find_s(X_branch, max_candidates = max_candidates)
- n_splits <- nrow(s)
-
- # remove splits resulting in to small leaves
- if(min_sample > 1) {
- s <- s[-c(0:(min_sample - 1), (n_splits - min_sample + 2):(n_splits+1)), ]
- }
- s <- matrix(s, ncol = p)
-
- optSplits <- lapply(1:p, function(j){
- s_j <- unique(s[, j])
- E_next <- lapply(s_j, function(si) {
- E_next <- matrix(0, nrow = n, ncol = 1)
- E_next[index[X_branch[, j] > si], ] <- 1
- if(sum(E_next) == 0)return(NULL)
- E_next
- })
- E_next <- do.call(cbind, E_next)
- if(is.null(E_next)) return(c(-10, j, 0, branch))
- U_next_prime <- Qf_temp(E_next, Ue, Qf)
- U_next_size <- colSums(U_next_prime ** 2)
- dloss <- as.numeric(crossprod(U_next_prime, Y_tilde))**2 / U_next_size
-
- opt <- which.max(unlist(dloss))
- c(dloss[[opt]], j, s_j[opt], branch)
- })
- memory[[branch]] <- do.call(rbind, optSplits)
- }
-
- if(i > after_mtry && !is.null(mtry)){
- Losses_dec <- lapply(memory, function(branch){
- branch[sample(1:p, mtry), ]})
- Losses_dec <- do.call(rbind, Losses_dec)
- }else {
- Losses_dec <- do.call(rbind, memory)
- }
-
- loc <- which.max(Losses_dec[, 1])
- best_branch <- Losses_dec[loc, 4]
- j <- Losses_dec[loc, 2]
- s <- Losses_dec[loc, 3]
-
- if(Losses_dec[loc, 1] <= 0){
- break
- }
-
- # divide observations in leaf
- index <- which(E[, best_branch] == 1)
- index_n_branches <- index[X[index, j] > s]
-
- # new indicator matrix
- E <- cbind(E, matrix(0, n, 1))
- E[index_n_branches, best_branch] <- 0
- E[index_n_branches, i+1] <- 1
-
- E_tilde_branch <- E_tilde[, best_branch]
- suppressWarnings({
- E_tilde[, best_branch] <- Qf(E[, best_branch])
- })
- E_tilde <- cbind(E_tilde, matrix(E_tilde_branch - E_tilde[, best_branch]))
-
- c_hat <- qr.coef(qr(E_tilde), Y_tilde)
-
- u_next_prime <- Qf_temp(E[, i + 1], Ue, Qf)
- Ue <- cbind(Ue, u_next_prime / sqrt(sum(u_next_prime ** 2)))
-
- # check if loss decrease is larger than minimum loss decrease
- # and if linear model could be estimated
- if(sum(is.na(as.numeric(c_hat))) > 0){
- warning('singulaer matrix QE, tree might be to large, consider increasing cp')
- break
- }
-
- loss_dec <- as.numeric(loss_temp - loss(Y_tilde, E_tilde %*% c_hat))
- loss_temp <- loss_temp - loss_dec
-
- if(loss_dec <= cp * loss_start){
- break
- }
- # add loss decrease to variable importance
- var_imp[j] <- var_imp[j] + loss_dec
-
- # add space for the two new leaves
- tree <- rbind(tree, matrix(0, nrow = 2, ncol = d))
-
- # select leaf to split
- leaves <- tree[, "leaf"] == 1
- toSplit <- leaves & (tree[, "name"] == best_branch)
- if(sum(toSplit) != 1) stop("Tries to split more than one leaf")
-
- # save split rule
- tree[toSplit, c("left", "right", "j", "s", "res_dloss", "leaf")] <-
- c(treeSize + 1, treeSize + 2, j, s, loss_dec, 2)
-
- # add new leaves
- tree[treeSize + 1, c("name", "dloss", "cp", "n_samples", "leaf")] <-
- c(tree[toSplit, "name"], loss_dec, loss_dec / loss_start, sum(E[, best_branch] == 1), 1)
- tree[treeSize + 2, c("name", "dloss", "cp", "n_samples", "leaf")] <-
- c(i + 1, loss_dec, loss_dec / loss_start, sum(E[, i + 1] == 1), 1)
- treeSize <- treeSize + 2
-
- # add estimates to tree leaves
- c_hat <- as.numeric(c_hat)
- # access leaf estimates by leaf names (i.e. columns of E)
- tree[tree[, "leaf"] == 1, "value"] <- c_hat[tree[tree[, "leaf"] == 1, "name"]]
-
- # the two new partitions need to be checked for optimal splits in next iteration
- potential_splits <- c(best_branch, i + 1)
-
- # a partition with less than min_sample observations or unique samples
- # are not available for further splits
- to_small <- sapply(potential_splits, function(x){
- new_samples <- nrow(unique(matrix(X[as.logical(E[, x]),], nrow = sum(E[, x]))))
- if(is.null(new_samples)) new_samples <- 0
- (new_samples < min_sample * 2)
- })
- if(sum(to_small) > 0){
- for(el in potential_splits[to_small]){
- # to small partitions cannot decrease the loss
- memory[[el]] <- matrix(0, p, 4)
- }
- potential_splits <- potential_splits[!to_small]
- }
- }
-
- if(i == max_leaves){
- warning('maximum number of iterations was reached, consider increasing m!')
- }
-
- # predict the test set
- f_X_hat <- traverse_tree(tree, X)
-
- var_names <- colnames(data.frame(X))
- names(var_imp) <- var_names
-
- # cp max of all splits after
- new_cp <- getCp_max(tree)
- tree[new_cp[[2]], "cp"] <- new_cp[[1]]
- # use max cp over siblings to ensure binary tree
- for(i in 1:nrow(tree)){
- if(tree[i, c("j")] != 0){
- tree[tree[i, c("left", "right")], "cp"] <-
- max(tree[tree[i, c("left", "right")], "cp"])
- }
- }
-
- res <- list(predictions = f_X_hat, tree = tree,
- var_names = var_names, var_importance = var_imp)
- class(res) <- 'SDTree'
- res
+ return(estimate_tree(boot_index = NULL, Y, X, A, max_leaves, cp, min_sample, mtry, fast,
+ Q_type, trim_quantile, q_hat, Qf, gamma, max_candidates,
+ Q_scale, predictors))
}
diff --git a/R/partDependence.R b/R/partDependence.R
index c514ade..5c87a1f 100644
--- a/R/partDependence.R
+++ b/R/partDependence.R
@@ -17,8 +17,11 @@
#' If NULL, tries to extract the dataset from the model object.
#' @param subSample Number of samples to draw from the original data for the empirical
#' partial dependence. If NULL, all the observations are used.
-#' @param mc.cores Number of cores to use for parallel computation.
-#' Parallel computing is only supported for unix.
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/)
+#' @param mc.cores Number of cores to use for parallel computation `vignette("Runtime")`.
+#' The `future` package is used for parallel processing.
+#' To use custom processing plans mc.cores has to be <= 1, see [`future` package](https://future.futureverse.org/).
#' @return An object of class \code{partDependence} containing
#' \item{preds_mean}{The average prediction for each value of the variable of interest.}
#' \item{x_seq}{The sequence of values for the variable of interest.}
@@ -34,7 +37,8 @@
#' plot(pd)
#' @seealso \code{\link{SDForest}}, \code{\link{SDTree}}
#' @export
-partDependence <- function(object, j, X = NULL, subSample = NULL, mc.cores = 1){
+partDependence <- function(object, j, X = NULL, subSample = NULL,
+ verbose = TRUE, mc.cores = 1){
j_name <- j
if(is.null(X)){
@@ -58,21 +62,22 @@ partDependence <- function(object, j, X = NULL, subSample = NULL, mc.cores = 1){
x_seq <- seq(min(X[, j]), max(X[, j]), length.out = 100)
- if(mc.cores > 1){
- preds <- parallel::mclapply(x_seq, function(x){
- X_new <- X
- X_new[, j] <- x
- pred <- predict(object, newdata = X_new)
- return(pred)
- }, mc.cores = mc.cores)
- }else{
- preds <- pbapply::pblapply(x_seq, function(x){
- X_new <- X
- X_new[, j] <- x
- pred <- predict(object, newdata = X_new)
- return(pred)
- })
- }
+ progressr::with_progress({
+ p <- progressr::progressor(along = x_seq, enable = verbose)
+ if(mc.cores > 1){
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = mc.cores), local = TRUE)
+ }
+ preds <- future.apply::future_lapply(future.seed = TRUE,
+ X = x_seq,
+ function(x){
+ X_new <- X
+ X_new[, j] <- x
+ pred <- predict(object, newdata = X_new)
+ p(sprintf("x=%g", x))
+ return(pred)
+ })
+ })
preds <- do.call(rbind, preds)
preds_mean <- rowMeans(preds)
diff --git a/R/paths.R b/R/paths.R
index 406553b..9a9ad96 100644
--- a/R/paths.R
+++ b/R/paths.R
@@ -58,6 +58,11 @@ regPath.SDTree <- function(object, cp_seq = NULL, ...){
#' @param X The training data, if NULL the data from the forest object is used.
#' @param Y The training response variable, if NULL the data from the forest object is used.
#' @param Q The transformation matrix, if NULL the data from the forest object is used.
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
+#' @param mc.cores Number of cores to use for parallel computation `vignette("Runtime")`.
+#' The `future` package is used for parallel processing.
+#' To use custom processing plans mc.cores has to be <= 1, see [`future` package](https://future.futureverse.org/).
#' @param ... Further arguments passed to or from other methods.
#' @return An object of class \code{paths} containing
#' \item{cp}{The sequence of complexity parameters.}
@@ -84,21 +89,29 @@ regPath.SDTree <- function(object, cp_seq = NULL, ...){
#'
#' @export
regPath.SDForest <- function(object, cp_seq = NULL, X = NULL, Y = NULL, Q = NULL,
- ...){
+ verbose = TRUE, mc.cores = 1, ...){
if(is.null(cp_seq)) cp_seq <- get_cp_seq(object)
cp_seq <- sort(cp_seq)
- res <- pbapply::pblapply(cp_seq, function(cp){
- pruned_object <- prune(object, cp, X, Y, Q, pred = FALSE)
- return(list(var_importance = pruned_object$var_importance,
- oob_SDloss = pruned_object$oob_SDloss,
- oob_loss = pruned_object$oob_loss))})
+ progressr::with_progress({
+ p <- progressr::progressor(along = cp_seq, enable = verbose)
+ if(mc.cores > 1){
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = mc.cores), local = TRUE)
+ }
+ res <- future.apply::future_lapply(future.seed = TRUE,
+ X = cp_seq,
+ function(cp){
+ pruned_object <- prune(object, cp, X, Y, Q, pred = FALSE)
+ p(sprintf("cp=%g", cp))
+ return(list(var_importance = pruned_object$var_importance,
+ oob_SDloss = pruned_object$oob_SDloss,
+ oob_loss = pruned_object$oob_loss))})
+ })
- #varImp_path <- t(sapply(res, function(x)x$var_importance))
varImp_path <- do.call(rbind, lapply(res, function(x)x$var_importance))
colnames(varImp_path) <- object$var_names
- #loss_path <- t(sapply(res, function(x) c(x$oob_SDloss, x$oob_loss)))
loss_path <- do.call(rbind, lapply(res, function(x) c(x$oob_SDloss, x$oob_loss)))
colnames(loss_path) <- c('oob SDE', 'oob MSE')
paths <- list(cp = cp_seq, varImp_path = varImp_path, loss_path = loss_path,
@@ -125,6 +138,8 @@ stabilitySelection <- function(object, ...) UseMethod('stabilitySelection')
#' @param object an SDForest object
#' @param cp_seq A sequence of complexity parameters.
#' If NULL, the sequence is calculated automatically using only relevant values.
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
#' @param ... Further arguments passed to or from other methods.
#' @return An object of class \code{paths} containing
#' \item{cp}{The sequence of complexity parameters.}
@@ -145,11 +160,22 @@ stabilitySelection <- function(object, ...) UseMethod('stabilitySelection')
#' plot(paths, plotly = TRUE)
#' }
#' @export
-stabilitySelection.SDForest <- function(object, cp_seq = NULL, ...){
+stabilitySelection.SDForest <- function(object, cp_seq = NULL,
+ verbose = TRUE, ...){
if(is.null(cp_seq)) cp_seq <- get_cp_seq(object)
cp_seq <- sort(cp_seq)
-
- imp <- pbapply::pblapply(object$forest, function(x)regPath(x, cp_seq)$varImp_path > 0)
+
+ progressr::with_progress({
+ p <- progressr::progressor(along = 1:length(object$forest), enable = verbose)
+ imp <- lapply(1:length(object$forest),
+ function(i){
+ path <- regPath(object$forest[[i]],
+ cp_seq, vebose = FALSE)$varImp_path > 0
+ p()
+ path
+ })
+ })
+
imp <- lapply(imp, function(x)matrix(as.numeric(x), ncol = ncol(x)))
imp <- Reduce('+', imp) / length(object$forest)
diff --git a/R/plot.R b/R/plot.R
index 1eb2f67..6f79785 100644
--- a/R/plot.R
+++ b/R/plot.R
@@ -44,7 +44,7 @@ plot.SDTree <- function(x, main = "", digits = 2, digits_decisions = 2,
if(weighted){
#res_dloss <- edges$res_dloss
#re scale edge weights
- edges$res_dloss <- (edges$res_dloss - min(edges$res_dloss)) / (max(edges$res_dloss) - min(edges$res_dloss)) * 2 + 0.5
+ edges$res_dloss <- (edges$res_dloss - min(edges$res_dloss)) / ((max(edges$res_dloss) - min(edges$res_dloss)) * 2 + 0.1) + 0.5
}else{
edges$res_dloss <- 0.5
}
@@ -66,7 +66,8 @@ plot.SDTree <- function(x, main = "", digits = 2, digits_decisions = 2,
ggplot2::annotate("segment", x = nLeaves*1.02, y = depth*0.98, xend = nLeaves*1.1, yend = depth*0.98,
arrow = ggplot2::arrow(length = ggplot2::unit(0.1, "inches")), color = "black") +
ggplot2::annotate("text", x = nLeaves * 1.05, y = depth, label = "no", size = 4) +
- ggplot2::ggtitle(main)
+ ggplot2::ggtitle(main) +
+ ggplot2::ylim(-0.1*depth, depth*1.1)
}
#' Plot performance of SDForest against number of trees
@@ -75,6 +76,8 @@ plot.SDTree <- function(x, main = "", digits = 2, digits_decisions = 2,
#' not stabilize one can fit another SDForest and merge the two.
#' @author Markus Ulmer
#' @param x Fitted object of class \code{SDForest}.
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
#' @param ... Further arguments passed to or from other methods.
#' @return A ggplot object
#' @seealso \code{\link{SDForest}}
@@ -85,26 +88,31 @@ plot.SDTree <- function(x, main = "", digits = 2, digits_decisions = 2,
#' model <- SDForest(x = X, y = y, Q_type = 'no_deconfounding', cp = 0.5, nTree = 500)
#' plot(model)
#' @export
-plot.SDForest <- function(x, ...){
+plot.SDForest <- function(x, verbose = TRUE, ...){
Y_ <- x$Q(x$Y)
# iterate over observations
- preds <- pbapply::pblapply(1:length(x$Y), function(i){
- if(length(x$oob_ind[[i]]) == 0){
- return(NA)
- }
- xi <- matrix(x$X[i, ], nrow = 1)
-
- # predict for each tree
- pred <- rep(NA, length(x$forest))
- model_idx <- x$oob_ind[[i]]
- model_idx <- model_idx[model_idx <= length(x$forest)]
- predictions <- sapply(model_idx, function(model){
- traverse_tree(x$forest[[model]]$tree, xi)
+ progressr::with_progress({
+ p <- progressr::progressor(along = 1:length(x$Y), enable = verbose)
+ preds <- lapply(1:length(x$Y), function(i){
+ if(length(x$oob_ind[[i]]) == 0){
+ return(NA)
+ }
+ xi <- matrix(x$X[i, ], nrow = 1)
+
+ # predict for each tree
+ pred <- rep(NA, length(x$forest))
+ model_idx <- x$oob_ind[[i]]
+ model_idx <- model_idx[model_idx <= length(x$forest)]
+ predictions <- sapply(model_idx, function(model){
+ traverse_tree(x$forest[[model]]$tree, xi)
+ })
+ pred[model_idx] <- predictions
+ p()
+ pred
})
- pred[model_idx] <- predictions
- pred
})
+
preds <- do.call(rbind, preds)
diff --git a/R/predict.R b/R/predict.R
index 033acd1..b8c8483 100644
--- a/R/predict.R
+++ b/R/predict.R
@@ -35,8 +35,11 @@ predict.SDTree <- function(object, newdata, ...){
#' @param object Fitted object of class \code{SDForest}.
#' @param newdata New test data of class \code{data.frame} containing
#' the covariates for which to predict the response.
-#' @param mc.cores Number of cores to use for parallel processing,
-#' if \code{mc.cores > 1} the trees predict in parallel.
+#' @param mc.cores Number of cores to use for parallel computation `vignette("Runtime")`.
+#' The `future` package is used for parallel processing.
+#' To use custom processing plans mc.cores has to be <= 1, see [`future` package](https://future.futureverse.org/).
+#' @param verbose If \code{TRUE} progress updates are shown using the `progressr` package.
+#' To customize the progress bar, see [`progressr` package](https://progressr.futureverse.org/articles/progressr-intro.html)
#' @param ... Further arguments passed to or from other methods.
#' @return A vector of predictions for the new data.
#' @examples
@@ -48,7 +51,7 @@ predict.SDTree <- function(object, newdata, ...){
#' predict(model, newdata = data.frame(X))
#' @seealso \code{\link{SDForest}}
#' @export
-predict.SDForest <- function(object, newdata, mc.cores = 1, ...){
+predict.SDForest <- function(object, newdata, mc.cores = 1, verbose = FALSE, ...){
# predict function for the spectral deconfounded random forest
# using the mean over all trees as the prediction
# check data type
@@ -62,9 +65,9 @@ predict.SDForest <- function(object, newdata, mc.cores = 1, ...){
if(any(is.na(X))) stop('X must not contain missing values')
- worker_fun <- function(tree){
+ worker_fun <- function(i){
preds_i <- tryCatch({
- preds <- traverse_tree(tree[["tree"]], X)
+ preds <- traverse_tree(object$forest[[i]][["tree"]], X)
list(ok = TRUE, preds = preds)
}, error = function(e) {
list(ok = FALSE, error = conditionMessage(e))
@@ -74,21 +77,14 @@ predict.SDForest <- function(object, newdata, mc.cores = 1, ...){
})
preds_i
}
+
if(mc.cores > 1){
- if(Sys.info()[["sysname"]] == "Linux"){
- preds_list <- parallel::mclapply(object$forest,
- worker_fun,
- mc.cores = mc.cores)
- }else{
- future::plan('multisession', workers = mc.cores)
- preds_list <- future.apply::future_lapply(future.seed = TRUE,
- X = object$forest,
- worker_fun)
- }
- }else{
- preds_list <- pbapply::pblapply(object$forest, worker_fun)
+ plan <- if (parallelly::supportsMulticore()) "multicore" else "multisession"
+ with(future::plan(plan, workers = mc.cores), local = TRUE)
}
-
+ preds_list <- future.apply::future_lapply(future.seed = TRUE,
+ X = 1:length(object$forest),
+ worker_fun)
#check worker statuses
failed_workers <- which(vapply(preds_list, function(z) !isTRUE(z$ok), logical(1)))
if (length(failed_workers) > 0) {
diff --git a/R/utility.R b/R/utility.R
index 754a820..d7ff1a5 100644
--- a/R/utility.R
+++ b/R/utility.R
@@ -82,7 +82,7 @@ split_names <- function(node, var_names = NULL, digits = 2){
}
-# finds all the reasonable spliting points in a data matrix
+# finds all the reasonable splitting points in a data matrix
find_s <- function(X, max_candidates = 100){
p <- ncol(X)
if(p == 1){
@@ -268,4 +268,266 @@ Bbasis <- function(x, breaks){
return(Bx)
}
+estimate_tree <- function(boot_index, Y, X, A, max_leaves, cp, min_sample, mtry, fast,
+ Q_type, trim_quantile, q_hat, Qf, gamma, max_candidates,
+ Q_scale, predictors){
+ if(is.null(boot_index)){
+ boot_index <- 1:nrow(X)
+ tree_in_forest <- FALSE
+ }else{
+ tree_in_forest <- TRUE
+ }
+ n <- length(boot_index)
+
+ # estimate spectral transformation
+ if(!is.null(A)){
+ if(is.null(gamma)) stop('gamma must be provided if A is provided')
+ if(is.vector(A)) A <- matrix(A)
+ if(!is.matrix(A)) stop('A must be a matrix')
+ if(nrow(A) != nrow(X)) stop('A must have n rows')
+ Wf <- get_Wf(matrix(A[boot_index, ], ncol = ncol(A)), gamma)
+ }else {
+ Wf <- function(v) v
+ }
+ if(is.null(Qf)){
+ if(!is.null(A)){
+ Qf <- function(v) get_Qf(Wf(X[boot_index, ]), Q_type, trim_quantile, q_hat, Q_scale)(Wf(v))
+ }else{
+ Qf <- get_Qf(X[boot_index, ], Q_type, trim_quantile, q_hat, Q_scale)
+ }
+ }else{
+ if(!is.function(Qf)) stop('Q must be a function')
+ if(length(Qf(rnorm(n))) == n) stop('Q must map from n to n')
+ }
+
+ #selection of predictors
+ if(!is.null(predictors)){
+ if(is.character(predictors)){
+ if(!all(predictors %in% colnames(X)))
+ stop("predictors must either be numeric columne index or in colnames of X")
+ predictors <- which(colnames(X) %in% predictors)
+ }
+ if(is.numeric(predictors)){
+ if(!all(predictors > 0 & predictors <= ncol(X)))
+ stop("predictors must either be numeric columne index or in colnames of X")
+ }
+ pred_names <- colnames(X)
+ X <- matrix(X[, predictors], ncol = length(predictors))
+ if(!is.null(pred_names)){
+ colnames(X) <- pred_names[predictors]
+ }
+ }
+
+ # number of covariates
+ p <- ncol(X)
+ if(!is.null(mtry) && mtry > p) stop('mtry must be at most p')
+
+ # calculate first estimate
+ E <- matrix(1, n, 1)
+ E_tilde <- Qf(E)
+ Ue <- E_tilde / sqrt(sum(E_tilde ** 2))
+ Y_tilde <- Qf(Y[boot_index])
+
+ # solve linear model
+ c_hat <- qr.coef(qr(E_tilde), Y_tilde)
+ c_hat <- as.numeric(c_hat)
+
+ loss_start <- as.numeric(sum((Y_tilde - c_hat) ** 2) / n)
+ loss_temp <- loss_start
+
+ # initialize tree
+ treeInfo <- c("name", "left", "right", "j", "s", "value", "dloss",
+ "res_dloss", "cp", "n_samples", "leaf")
+ d <- length(treeInfo)
+
+ tree <- matrix(0, ncol = d, nrow = 1, dimnames = list(NULL, treeInfo))
+ tree[1, c("name", "value", "dloss", "cp", "n_samples", "leaf")] <-
+ c(1, c_hat, loss_start, 10, n, 1)
+ treeSize <- 1
+
+ # memory for optimal splits
+ memory <- list()
+ potential_splits <- 1
+
+ # variable importance
+ var_imp <- rep(0, p)
+ names(var_imp) <- colnames(X)
+
+ after_mtry <- 0
+
+ for(i in 1:max_leaves){
+ # iterate over all possible splits every time
+ # for slow but slightly better solution
+ if(!fast){
+ potential_splits <- 1:i
+ to_small <- sapply(potential_splits,
+ function(x){sum(E[, x]) < min_sample*2})
+ potential_splits <- potential_splits[!to_small]
+ }
+
+ #iterate over new to estimate splits
+ for(branch in potential_splits){
+ # get samples in branch to evaluate
+ E_branch <- E[, branch]
+ index <- which(E_branch == 1)
+ X_branch <- matrix(X[boot_index[index], ], nrow = length(index))
+
+ # get potential splitting candidates
+ s <- find_s(X_branch, max_candidates = max_candidates)
+ n_splits <- nrow(s)
+
+ # remove splits resulting in to small leaves
+ if(min_sample > 1) {
+ s <- s[-c(0:(min_sample - 1), (n_splits - min_sample + 2):(n_splits+1)), ]
+ }
+ s <- matrix(s, ncol = p)
+
+ optSplits <- lapply(1:p, function(j){
+ s_j <- unique(s[, j])
+ E_next <- lapply(s_j, function(si) {
+ E_next <- matrix(0, nrow = n, ncol = 1)
+ E_next[index[X_branch[, j] > si], ] <- 1
+ if(sum(E_next) == 0)return(NULL)
+ E_next
+ })
+ E_next <- do.call(cbind, E_next)
+ if(is.null(E_next)) return(c(-10, j, 0, branch))
+ U_next_prime <- Qf_temp(E_next, Ue, Qf)
+ U_next_size <- colSums(U_next_prime ** 2)
+ dloss <- as.numeric(crossprod(U_next_prime, Y_tilde))**2 / U_next_size
+
+ opt <- which.max(unlist(dloss))
+ c(dloss[[opt]], j, s_j[opt], branch)
+ })
+ memory[[branch]] <- do.call(rbind, optSplits)
+ }
+
+ if(i > after_mtry && !is.null(mtry)){
+ Losses_dec <- lapply(memory, function(branch){
+ branch[sample(1:p, mtry), ]})
+ Losses_dec <- do.call(rbind, Losses_dec)
+ }else {
+ Losses_dec <- do.call(rbind, memory)
+ }
+
+ loc <- which.max(Losses_dec[, 1])
+ best_branch <- Losses_dec[loc, 4]
+ j <- Losses_dec[loc, 2]
+ s <- Losses_dec[loc, 3]
+
+ if(Losses_dec[loc, 1] <= 0){
+ break
+ }
+
+ # divide observations in leaf
+ index <- which(E[, best_branch] == 1)
+ index_n_branches <- index[X[boot_index[index], j] > s]
+
+ # new indicator matrix
+ E <- cbind(E, matrix(0, n, 1))
+ E[index_n_branches, best_branch] <- 0
+ E[index_n_branches, i+1] <- 1
+
+ E_tilde_branch <- E_tilde[, best_branch]
+ suppressWarnings({
+ E_tilde[, best_branch] <- Qf(E[, best_branch])
+ })
+ E_tilde <- cbind(E_tilde, matrix(E_tilde_branch - E_tilde[, best_branch]))
+
+ c_hat <- qr.coef(qr(E_tilde), Y_tilde)
+
+ u_next_prime <- Qf_temp(E[, i + 1], Ue, Qf)
+ Ue <- cbind(Ue, u_next_prime / sqrt(sum(u_next_prime ** 2)))
+
+ # check if loss decrease is larger than minimum loss decrease
+ # and if linear model could be estimated
+ if(sum(is.na(as.numeric(c_hat))) > 0){
+ warning('singulaer matrix QE, tree might be to large, consider increasing cp')
+ break
+ }
+
+ loss_dec <- as.numeric(loss_temp - loss(Y_tilde, E_tilde %*% c_hat))
+ loss_temp <- loss_temp - loss_dec
+
+ if(loss_dec <= cp * loss_start){
+ break
+ }
+ # add loss decrease to variable importance
+ var_imp[j] <- var_imp[j] + loss_dec
+
+ # add space for the two new leaves
+ tree <- rbind(tree, matrix(0, nrow = 2, ncol = d))
+
+ # select leaf to split
+ leaves <- tree[, "leaf"] == 1
+ toSplit <- leaves & (tree[, "name"] == best_branch)
+ if(sum(toSplit) != 1) stop("Tries to split more than one leaf")
+
+ # save split rule
+ tree[toSplit, c("left", "right", "j", "s", "res_dloss", "leaf")] <-
+ c(treeSize + 1, treeSize + 2, j, s, loss_dec, 2)
+
+ # add new leaves
+ tree[treeSize + 1, c("name", "dloss", "cp", "n_samples", "leaf")] <-
+ c(tree[toSplit, "name"], loss_dec, loss_dec / loss_start, sum(E[, best_branch] == 1), 1)
+ tree[treeSize + 2, c("name", "dloss", "cp", "n_samples", "leaf")] <-
+ c(i + 1, loss_dec, loss_dec / loss_start, sum(E[, i + 1] == 1), 1)
+ treeSize <- treeSize + 2
+
+ # add estimates to tree leaves
+ c_hat <- as.numeric(c_hat)
+ # access leaf estimates by leaf names (i.e. columns of E)
+ tree[tree[, "leaf"] == 1, "value"] <- c_hat[tree[tree[, "leaf"] == 1, "name"]]
+
+ # the two new partitions need to be checked for optimal splits in next iteration
+ potential_splits <- c(best_branch, i + 1)
+
+ # a partition with less than min_sample observations or unique samples
+ # are not available for further splits
+ to_small <- sapply(potential_splits, function(x){
+ new_samples <- nrow(unique(matrix(X[boot_index[as.logical(E[, x])],], nrow = sum(E[, x]))))
+ if(is.null(new_samples)) new_samples <- 0
+ (new_samples < min_sample * 2)
+ })
+ if(sum(to_small) > 0){
+ for(el in potential_splits[to_small]){
+ # to small partitions cannot decrease the loss
+ memory[[el]] <- matrix(0, p, 4)
+ }
+ potential_splits <- potential_splits[!to_small]
+ }
+ }
+
+ if(i == max_leaves){
+ warning('maximum number of iterations was reached, consider increasing m!')
+ }
+
+ # predict the test set
+ if(tree_in_forest){
+ f_X_hat <- NULL
+ }else{
+ f_X_hat <- traverse_tree(tree, X)
+ }
+
+
+ var_names <- colnames(data.frame(X))
+ names(var_imp) <- var_names
+
+ # cp max of all splits after
+ new_cp <- getCp_max(tree)
+ tree[new_cp[[2]], "cp"] <- new_cp[[1]]
+
+ # use max cp over siblings to ensure binary tree
+ for(i in 1:nrow(tree)){
+ if(tree[i, c("j")] != 0){
+ tree[tree[i, c("left", "right")], "cp"] <-
+ max(tree[tree[i, c("left", "right")], "cp"])
+ }
+ }
+
+ res <- list(predictions = f_X_hat, tree = tree,
+ var_names = var_names, var_importance = var_imp)
+ class(res) <- 'SDTree'
+ res
+}
diff --git a/README.Rmd b/README.Rmd
index e2f3708..0ade575 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -75,7 +75,7 @@ You can also estimate just one Spectrally Deconfounded Regression Tree using the
```{r SDTree, fig.height=7}
Tree <- SDTree(Y ~ ., train_data, cp = 0.01)
-plot(Tree)
+#plot(Tree)
```
Or you can estimate a Spectrally Deconfounded Additive Model, with theoretical guarantees, using the `SDAM` function. See also the article [SDAM](https://www.markus-ulmer.ch/SDModels/articles/SDAM.html).
diff --git a/README.md b/README.md
index 62bc28f..ba5d343 100644
--- a/README.md
+++ b/README.md
@@ -76,8 +76,8 @@ fit
#>
#> Number of trees: 100
#> Number of covariates: 50
-#> OOB loss: 0.1554798
-#> OOB spectral loss: 0.05246865
+#> OOB loss: 0.1617913
+#> OOB spectral loss: 0.05095329
```
You can also estimate just one Spectrally Deconfounded Regression Tree
@@ -86,24 +86,20 @@ using the `SDTree` function. See also the article
``` r
Tree <- SDTree(Y ~ ., train_data, cp = 0.01)
-plot(Tree)
+#plot(Tree)
```
-
-
Or you can estimate a Spectrally Deconfounded Additive Model, with
theoretical guarantees, using the `SDAM` function. See also the article
[SDAM](https://www.markus-ulmer.ch/SDModels/articles/SDAM.html).
``` r
model <- SDAM(Y ~ ., train_data)
-#> [1] "Initial cross-validation"
-#> [1] "Second stage cross-validation"
model
#> SDAM result
#>
#> Number of covariates: 50
-#> Number of active covariates: 4
+#> Number of active covariates: 3
```