Skip to content

Commit 91a455c

Browse files
authored
Merge pull request #143 from StochasticTree/cran-comments-0.1.0-round1
Addressing first round of CRAN comments for 0.1.0 release
2 parents d8918b4 + 6a2d229 commit 91a455c

34 files changed: +334 additions, -331 deletions

DESCRIPTION

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,17 @@ Authors@R:
88
person("Jared", "Murray", role = "aut"),
99
person("Carlos", "Carvalho", role = "aut"),
1010
person("Jingyu", "He", role = "aut"),
11-
person("stochtree contributors", role = c("cph"))
11+
person("Pedro", "Lima", role = "ctb"),
12+
person("stochtree", "contributors", role = c("cph")),
13+
person("Eigen", "contributors", role = c("cph"), comment = "C++ source uses the Eigen library for matrix operations, see inst/COPYRIGHTS"),
14+
person("xgboost", "contributors", role = c("cph"), comment = "C++ tree code and related operations include or are inspired by code from the xgboost library, see inst/COPYRIGHTS"),
15+
person("treelite", "contributors", role = c("cph"), comment = "C++ tree code and related operations include or are inspired by code from the treelite library, see inst/COPYRIGHTS"),
16+
person("Microsoft", "Corporation", role = c("cph"), comment = "C++ I/O and various project structure code include or are inspired by code from the LightGBM library, which is a copyright of Microsoft, see inst/COPYRIGHTS"),
17+
person("Niels", "Lohmann", role = c("cph"), comment = "C++ source uses the JSON for Modern C++ library for JSON operations, see inst/COPYRIGHTS"),
18+
person("Daniel", "Lemire", role = c("cph"), comment = "C++ source uses the fast_double_parser library internally, see inst/COPYRIGHTS"),
19+
person("Victor", "Zverovich", role = c("cph"), comment = "C++ source uses the fmt library internally, see inst/COPYRIGHTS")
1220
)
21+
Copyright: Copyright details for stochtree's C++ dependencies, which are vendored along with the core stochtree source code, are detailed in inst/COPYRIGHTS
1322
Description: Flexible stochastic tree ensemble software.
1423
Robust implementations of Bayesian Additive Regression Trees (BART)
1524
Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>

R/bart.R

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@
9999
#' y_train <- y[train_inds]
100100
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
101101
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
102-
#' plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
103-
#' abline(0,1,col="red",lty=3,lwd=3)
104102
bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL,
105103
rfx_basis_train = NULL, X_test = NULL, leaf_basis_test = NULL,
106104
rfx_group_ids_test = NULL, rfx_basis_test = NULL,
@@ -110,12 +108,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
110108
variance_forest_params = list()) {
111109
# Update general BART parameters
112110
general_params_default <- list(
113-
cutpoint_grid_size = 100, standardize = T,
114-
sample_sigma2_global = T, sigma2_global_init = NULL,
111+
cutpoint_grid_size = 100, standardize = TRUE,
112+
sample_sigma2_global = TRUE, sigma2_global_init = NULL,
115113
sigma2_global_shape = 0, sigma2_global_scale = 0,
116114
variable_weights = NULL, random_seed = -1,
117-
keep_burnin = F, keep_gfr = F, keep_every = 1,
118-
num_chains = 1, verbose = F
115+
keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
116+
num_chains = 1, verbose = FALSE
119117
)
120118
general_params_updated <- preprocessParams(
121119
general_params_default, general_params
@@ -125,7 +123,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
125123
mean_forest_params_default <- list(
126124
num_trees = 200, alpha = 0.95, beta = 2.0,
127125
min_samples_leaf = 5, max_depth = 10,
128-
sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
126+
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
129127
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
130128
keep_vars = NULL, drop_vars = NULL
131129
)
@@ -197,7 +195,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
197195
}
198196

199197
# Override keep_gfr if there are no MCMC samples
200-
if (num_mcmc == 0) keep_gfr <- T
198+
if (num_mcmc == 0) keep_gfr <- TRUE
201199

202200
# Check if previous model JSON is provided and parse it if so
203201
has_prev_model <- !is.null(previous_model_json)
@@ -238,10 +236,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
238236
}
239237

240238
# Determine whether conditional mean, variance, or both will be modeled
241-
if (num_trees_variance > 0) include_variance_forest = T
242-
else include_variance_forest = F
243-
if (num_trees_mean > 0) include_mean_forest = T
244-
else include_mean_forest = F
239+
if (num_trees_variance > 0) include_variance_forest = TRUE
240+
else include_variance_forest = FALSE
241+
if (num_trees_mean > 0) include_mean_forest = TRUE
242+
else include_mean_forest = FALSE
245243

246244
# Set the variance forest priors if not set
247245
if (include_variance_forest) {
@@ -253,7 +251,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
253251
}
254252

255253
# Override tau sampling if there is no mean forest
256-
if (!include_mean_forest) sample_sigma_leaf <- F
254+
if (!include_mean_forest) sample_sigma_leaf <- FALSE
257255

258256
# Variable weight preprocessing (and initialization if necessary)
259257
if (is.null(variable_weights)) {
@@ -388,19 +386,19 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
388386
}
389387

390388
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
391-
has_rfx <- F
392-
has_rfx_test <- F
389+
has_rfx <- FALSE
390+
has_rfx_test <- FALSE
393391
if (!is.null(rfx_group_ids_train)) {
394392
group_ids_factor <- factor(rfx_group_ids_train)
395393
rfx_group_ids_train <- as.integer(group_ids_factor)
396-
has_rfx <- T
394+
has_rfx <- TRUE
397395
if (!is.null(rfx_group_ids_test)) {
398396
group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor))
399397
if (sum(is.na(group_ids_factor_test)) > 0) {
400398
stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train")
401399
}
402400
rfx_group_ids_test <- as.integer(group_ids_factor_test)
403-
has_rfx_test <- T
401+
has_rfx_test <- TRUE
404402
}
405403
}
406404

@@ -432,13 +430,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
432430
}
433431

434432
# Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
435-
has_basis_rfx <- F
433+
has_basis_rfx <- FALSE
436434
num_basis_rfx <- 0
437435
if (has_rfx) {
438436
if (is.null(rfx_basis_train)) {
439437
rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1)
440438
} else {
441-
has_basis_rfx <- T
439+
has_basis_rfx <- TRUE
442440
num_basis_rfx <- ncol(rfx_basis_train)
443441
}
444442
num_rfx_groups <- length(unique(rfx_group_ids_train))
@@ -520,40 +518,40 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
520518
# Unpack model type info
521519
if (leaf_model_mean_forest == 0) {
522520
leaf_dimension = 1
523-
is_leaf_constant = T
524-
leaf_regression = F
521+
is_leaf_constant = TRUE
522+
leaf_regression = FALSE
525523
} else if (leaf_model_mean_forest == 1) {
526524
stopifnot(has_basis)
527525
stopifnot(ncol(leaf_basis_train) == 1)
528526
leaf_dimension = 1
529-
is_leaf_constant = F
530-
leaf_regression = T
527+
is_leaf_constant = FALSE
528+
leaf_regression = TRUE
531529
} else if (leaf_model_mean_forest == 2) {
532530
stopifnot(has_basis)
533531
stopifnot(ncol(leaf_basis_train) > 1)
534532
leaf_dimension = ncol(leaf_basis_train)
535-
is_leaf_constant = F
536-
leaf_regression = T
533+
is_leaf_constant = FALSE
534+
leaf_regression = TRUE
537535
if (sample_sigma_leaf) {
538536
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
539-
sample_sigma_leaf <- F
537+
sample_sigma_leaf <- FALSE
540538
}
541539
}
542540

543541
# Data
544542
if (leaf_regression) {
545543
forest_dataset_train <- createForestDataset(X_train, leaf_basis_train)
546544
if (has_test) forest_dataset_test <- createForestDataset(X_test, leaf_basis_test)
547-
requires_basis <- T
545+
requires_basis <- TRUE
548546
} else {
549547
forest_dataset_train <- createForestDataset(X_train)
550548
if (has_test) forest_dataset_test <- createForestDataset(X_test)
551-
requires_basis <- F
549+
requires_basis <- FALSE
552550
}
553551
outcome_train <- createOutcome(resid_train)
554552

555553
# Random number generator (std::mt19937)
556-
if (is.null(random_seed)) random_seed = sample(1:10000,1,F)
554+
if (is.null(random_seed)) random_seed = sample(1:10000,1,FALSE)
557555
rng <- createCppRNG(random_seed)
558556

559557
# Sampling data structures
@@ -630,7 +628,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
630628
if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
631629
else init_values_mean_forest <- 0.
632630
active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest)
633-
active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F)
631+
active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, FALSE)
634632
}
635633

636634
# Initialize the leaves of each tree in the variance forest
@@ -643,8 +641,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
643641
if (num_gfr > 0){
644642
for (i in 1:num_gfr) {
645643
# Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
646-
# keep_sample <- ifelse(keep_gfr, T, F)
647-
keep_sample <- T
644+
# keep_sample <- ifelse(keep_gfr, TRUE, FALSE)
645+
keep_sample <- TRUE
648646
if (keep_sample) sample_counter <- sample_counter + 1
649647
# Print progress
650648
if (verbose) {
@@ -657,14 +655,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
657655
forest_model_mean$sample_one_iteration(
658656
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
659657
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
660-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
658+
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
661659
)
662660
}
663661
if (include_variance_forest) {
664662
forest_model_variance$sample_one_iteration(
665663
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
666664
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
667-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
665+
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
668666
)
669667
}
670668
if (sample_sigma_global) {
@@ -771,11 +769,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
771769
is_mcmc <- i > (num_gfr + num_burnin)
772770
if (is_mcmc) {
773771
mcmc_counter <- i - (num_gfr + num_burnin)
774-
if (mcmc_counter %% keep_every == 0) keep_sample <- T
775-
else keep_sample <- F
772+
if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE
773+
else keep_sample <- FALSE
776774
} else {
777-
if (keep_burnin) keep_sample <- T
778-
else keep_sample <- F
775+
if (keep_burnin) keep_sample <- TRUE
776+
else keep_sample <- FALSE
779777
}
780778
if (keep_sample) sample_counter <- sample_counter + 1
781779
# Print progress
@@ -796,14 +794,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
796794
forest_model_mean$sample_one_iteration(
797795
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
798796
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
799-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
797+
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
800798
)
801799
}
802800
if (include_variance_forest) {
803801
forest_model_variance$sample_one_iteration(
804802
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
805803
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
806-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
804+
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
807805
)
808806
}
809807
if (sample_sigma_global) {
@@ -994,8 +992,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
994992
#' bart_model <- bart(X_train = X_train, y_train = y_train,
995993
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
996994
#' y_hat_test <- predict(bart_model, X_test)$y_hat
997-
#' plot(rowMeans(y_hat_test), y_test, xlab = "predicted", ylab = "actual")
998-
#' abline(0,1,col="red",lty=3,lwd=3)
999995
predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...){
1000996
# Preprocess covariates
1001997
if ((!is.data.frame(X)) && (!is.matrix(X))) {
@@ -1033,15 +1029,15 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
10331029
}
10341030

10351031
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
1036-
has_rfx <- F
1032+
has_rfx <- FALSE
10371033
if (!is.null(rfx_group_ids)) {
10381034
rfx_unique_group_ids <- object$rfx_unique_group_ids
10391035
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
10401036
if (sum(is.na(group_ids_factor)) > 0) {
10411037
stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train")
10421038
}
10431039
rfx_group_ids <- as.integer(group_ids_factor)
1044-
has_rfx <- T
1040+
has_rfx <- TRUE
10451041
}
10461042

10471043
# Produce basis for the "intercept-only" random effects case
@@ -1557,8 +1553,6 @@ createBARTModelFromJsonFile <- function(json_filename){
15571553
#' bart_json <- saveBARTModelToJsonString(bart_model)
15581554
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
15591555
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
1560-
#' plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip,
1561-
#' xlab = "original", ylab = "roundtrip")
15621556
createBARTModelFromJsonString <- function(json_string){
15631557
# Load a `CppJson` object from string
15641558
bart_json <- createCppJsonString(json_string)

0 commit comments

Comments (0)