#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
- #' plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
- #' abline(0,1,col="red",lty=3,lwd=3)
bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL,
                 rfx_basis_train = NULL, X_test = NULL, leaf_basis_test = NULL,
                 rfx_group_ids_test = NULL, rfx_basis_test = NULL,
@@ -110,12 +108,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                 variance_forest_params = list()) {
    # Update general BART parameters
    general_params_default <- list(
-       cutpoint_grid_size = 100, standardize = T,
-       sample_sigma2_global = T, sigma2_global_init = NULL,
+       cutpoint_grid_size = 100, standardize = TRUE,
+       sample_sigma2_global = TRUE, sigma2_global_init = NULL,
        sigma2_global_shape = 0, sigma2_global_scale = 0,
        variable_weights = NULL, random_seed = -1,
-       keep_burnin = F, keep_gfr = F, keep_every = 1,
-       num_chains = 1, verbose = F
+       keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
+       num_chains = 1, verbose = FALSE
    )
    general_params_updated <- preprocessParams(
        general_params_default, general_params
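The substantive change in this hunk, and in most of the hunks that follow, is replacing the shorthand logicals `T`/`F` with `TRUE`/`FALSE`. This matters in package code because `TRUE` and `FALSE` are reserved words, while `T` and `F` are ordinary bindings that merely default to those values and can be reassigned by a user. A minimal sketch of the failure mode this avoids (variable names are illustrative, not from the package):

# TRUE/FALSE are reserved words and cannot be overwritten:
# TRUE <- 1   # error: cannot assign to a reserved word
# T and F are just bindings in base R that default to TRUE/FALSE:
F <- 1                              # legal, and now F is no longer FALSE
standardize_flag <- F               # silently picks up 1 instead of FALSE
identical(standardize_flag, FALSE)  # FALSE -- the flag no longer means what it reads as
rm(F)                               # restores the base binding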
@@ -125,7 +123,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    mean_forest_params_default <- list(
        num_trees = 200, alpha = 0.95, beta = 2.0,
        min_samples_leaf = 5, max_depth = 10,
-       sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
+       sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
        sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
        keep_vars = NULL, drop_vars = NULL
    )
@@ -197,7 +195,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    }

    # Override keep_gfr if there are no MCMC samples
-   if (num_mcmc == 0) keep_gfr <- T
+   if (num_mcmc == 0) keep_gfr <- TRUE

    # Check if previous model JSON is provided and parse it if so
    has_prev_model <- !is.null(previous_model_json)
@@ -238,10 +236,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    }

    # Determine whether conditional mean, variance, or both will be modeled
-   if (num_trees_variance > 0) include_variance_forest = T
-   else include_variance_forest = F
-   if (num_trees_mean > 0) include_mean_forest = T
-   else include_mean_forest = F
+   if (num_trees_variance > 0) include_variance_forest = TRUE
+   else include_variance_forest = FALSE
+   if (num_trees_mean > 0) include_mean_forest = TRUE
+   else include_mean_forest = FALSE

    # Set the variance forest priors if not set
    if (include_variance_forest) {
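The hunk above is also where the sampler decides whether to fit a mean forest, a variance forest, or both: a variance forest is included only when the user-supplied variance forest parameters request a positive number of trees. A hedged usage sketch, assuming `variance_forest_params` accepts a `num_trees` entry analogous to the mean-forest default shown earlier:

# Mean forest only (the default configuration)
bart_mean_only <- bart(X_train = X_train, y_train = y_train,
                       num_gfr = 10, num_burnin = 0, num_mcmc = 100)

# Also model conditional variance by requesting trees for the variance forest
bart_heteroskedastic <- bart(X_train = X_train, y_train = y_train,
                             num_gfr = 10, num_burnin = 0, num_mcmc = 100,
                             variance_forest_params = list(num_trees = 50))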
@@ -253,7 +251,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    }

    # Override tau sampling if there is no mean forest
-   if (!include_mean_forest) sample_sigma_leaf <- F
+   if (!include_mean_forest) sample_sigma_leaf <- FALSE

    # Variable weight preprocessing (and initialization if necessary)
    if (is.null(variable_weights)) {
@@ -388,19 +386,19 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    }

    # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
-   has_rfx <- F
-   has_rfx_test <- F
+   has_rfx <- FALSE
+   has_rfx_test <- FALSE
    if (!is.null(rfx_group_ids_train)) {
        group_ids_factor <- factor(rfx_group_ids_train)
        rfx_group_ids_train <- as.integer(group_ids_factor)
-       has_rfx <- T
+       has_rfx <- TRUE
        if (!is.null(rfx_group_ids_test)) {
            group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor))
            if (sum(is.na(group_ids_factor_test)) > 0) {
                stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train")
            }
            rfx_group_ids_test <- as.integer(group_ids_factor_test)
-           has_rfx_test <- T
+           has_rfx_test <- TRUE
        }
    }

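The recoding step in the hunk above is worth spelling out: training labels define the factor levels, and test labels are converted against those same levels, so a test-only group becomes `NA` and triggers the `stop()` call above. A self-contained sketch of that mechanism (the group labels are made up):

train_labels <- c("cook", "dupage", "cook", "lake")
test_labels  <- c("lake", "cook", "will")    # "will" never appears in training

group_ids_factor <- factor(train_labels)
as.integer(group_ids_factor)                 # 1 2 1 3: integer codes used internally

test_factor <- factor(test_labels, levels = levels(group_ids_factor))
as.integer(test_factor)                      # 3 1 NA: the NA is what the check guards against
sum(is.na(test_factor)) > 0                  # TRUE, so bart() would stop with an error here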
@@ -432,13 +430,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    }

    # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
-   has_basis_rfx <- F
+   has_basis_rfx <- FALSE
    num_basis_rfx <- 0
    if (has_rfx) {
        if (is.null(rfx_basis_train)) {
            rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1)
        } else {
-           has_basis_rfx <- T
+           has_basis_rfx <- TRUE
            num_basis_rfx <- ncol(rfx_basis_train)
        }
        num_rfx_groups <- length(unique(rfx_group_ids_train))
@@ -520,40 +518,40 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    # Unpack model type info
    if (leaf_model_mean_forest == 0) {
        leaf_dimension = 1
-       is_leaf_constant = T
-       leaf_regression = F
+       is_leaf_constant = TRUE
+       leaf_regression = FALSE
    } else if (leaf_model_mean_forest == 1) {
        stopifnot(has_basis)
        stopifnot(ncol(leaf_basis_train) == 1)
        leaf_dimension = 1
-       is_leaf_constant = F
-       leaf_regression = T
+       is_leaf_constant = FALSE
+       leaf_regression = TRUE
    } else if (leaf_model_mean_forest == 2) {
        stopifnot(has_basis)
        stopifnot(ncol(leaf_basis_train) > 1)
        leaf_dimension = ncol(leaf_basis_train)
-       is_leaf_constant = F
-       leaf_regression = T
+       is_leaf_constant = FALSE
+       leaf_regression = TRUE
        if (sample_sigma_leaf) {
            warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
-           sample_sigma_leaf <- F
+           sample_sigma_leaf <- FALSE
        }
    }

    # Data
    if (leaf_regression) {
        forest_dataset_train <- createForestDataset(X_train, leaf_basis_train)
        if (has_test) forest_dataset_test <- createForestDataset(X_test, leaf_basis_test)
-       requires_basis <- T
+       requires_basis <- TRUE
    } else {
        forest_dataset_train <- createForestDataset(X_train)
        if (has_test) forest_dataset_test <- createForestDataset(X_test)
-       requires_basis <- F
+       requires_basis <- FALSE
    }
    outcome_train <- createOutcome(resid_train)

    # Random number generator (std::mt19937)
-   if (is.null(random_seed)) random_seed = sample(1:10000,1,F)
+   if (is.null(random_seed)) random_seed = sample(1:10000,1,FALSE)
    rng <- createCppRNG(random_seed)

    # Sampling data structures
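For orientation, the dispatch in the hunk above is driven entirely by the leaf basis: no `leaf_basis_train` yields a constant-leaf forest (model 0), a one-column basis yields a univariate leaf regression (model 1), and a multi-column basis yields a multivariate leaf regression (model 2), for which leaf-scale sampling is disabled with a warning. A hedged usage sketch of supplying a basis (the basis construction here is purely illustrative):

# Constant leaves: no leaf basis supplied
bart_const <- bart(X_train = X_train, y_train = y_train,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 100)

# Univariate leaf regression: a single-column basis (e.g., a trend or treatment variable)
W_train <- matrix(rnorm(nrow(X_train)), ncol = 1)
bart_leaf_reg <- bart(X_train = X_train, y_train = y_train,
                      leaf_basis_train = W_train,
                      num_gfr = 10, num_burnin = 0, num_mcmc = 100)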
@@ -630,7 +628,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
        if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
        else init_values_mean_forest <- 0.
        active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest)
-       active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F)
+       active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, FALSE)
    }

    # Initialize the leaves of each tree in the variance forest
@@ -643,8 +641,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
    if (num_gfr > 0) {
        for (i in 1:num_gfr) {
            # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
-           # keep_sample <- ifelse(keep_gfr, T, F)
-           keep_sample <- T
+           # keep_sample <- ifelse(keep_gfr, TRUE, FALSE)
+           keep_sample <- TRUE
            if (keep_sample) sample_counter <- sample_counter + 1
            # Print progress
            if (verbose) {
@@ -657,14 +655,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                forest_model_mean$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                    active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
-                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
+                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                )
            }
            if (include_variance_forest) {
                forest_model_variance$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
                    active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
-                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
+                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                )
            }
            if (sample_sigma_global) {
@@ -771,11 +769,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
            is_mcmc <- i > (num_gfr + num_burnin)
            if (is_mcmc) {
                mcmc_counter <- i - (num_gfr + num_burnin)
-               if (mcmc_counter %% keep_every == 0) keep_sample <- T
-               else keep_sample <- F
+               if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE
+               else keep_sample <- FALSE
            } else {
-               if (keep_burnin) keep_sample <- T
-               else keep_sample <- F
+               if (keep_burnin) keep_sample <- TRUE
+               else keep_sample <- FALSE
            }
            if (keep_sample) sample_counter <- sample_counter + 1
            # Print progress
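The branching above implements thinning and burn-in retention: after burn-in, a draw is kept only when the MCMC counter is a multiple of `keep_every`; burn-in draws are kept only when `keep_burnin` is set. The same decision can be read as the short equivalent below (a restatement for clarity, not a proposed change):

# Equivalent keep/discard rule for iteration i of the sampling loop
is_mcmc <- i > (num_gfr + num_burnin)
keep_sample <- if (is_mcmc) {
    ((i - (num_gfr + num_burnin)) %% keep_every) == 0
} else {
    keep_burnin
}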
@@ -796,14 +794,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                forest_model_mean$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                    active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
-                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
+                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                )
            }
            if (include_variance_forest) {
                forest_model_variance$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
                    active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
-                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
+                   global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                )
            }
            if (sample_sigma_global) {
@@ -994,8 +992,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
#' bart_model <- bart(X_train = X_train, y_train = y_train,
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' y_hat_test <- predict(bart_model, X_test)$y_hat
- #' plot(rowMeans(y_hat_test), y_test, xlab = "predicted", ylab = "actual")
- #' abline(0,1,col="red",lty=3,lwd=3)
predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...) {
    # Preprocess covariates
    if ((!is.data.frame(X)) && (!is.matrix(X))) {
@@ -1033,15 +1029,15 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
    }

    # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
-   has_rfx <- F
+   has_rfx <- FALSE
    if (!is.null(rfx_group_ids)) {
        rfx_unique_group_ids <- object$rfx_unique_group_ids
        group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
        if (sum(is.na(group_ids_factor)) > 0) {
            stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train")
        }
        rfx_group_ids <- as.integer(group_ids_factor)
-       has_rfx <- T
+       has_rfx <- TRUE
    }

    # Produce basis for the "intercept-only" random effects case
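`predict.bartmodel()` applies the same label-to-integer recoding, but against the factor levels stored on the fitted model, so any group label supplied at prediction time must have been present in training. A hedged usage sketch (the object and variable names mirror the roxygen examples but are illustrative):

# Prediction with random effects: group labels must match those seen during training
preds <- predict(bart_model, X_test, rfx_group_ids = rfx_group_ids_test,
                 rfx_basis = rfx_basis_test)
y_hat_test <- rowMeans(preds$y_hat)   # average over retained posterior draws

# A label never seen in training would hit the stop() shown above:
# predict(bart_model, X_test, rfx_group_ids = c("unseen_group"))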
@@ -1557,8 +1553,6 @@ createBARTModelFromJsonFile <- function(json_filename){
#' bart_json <- saveBARTModelToJsonString(bart_model)
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
- #' plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip,
- #'      xlab = "original", ylab = "roundtrip")
createBARTModelFromJsonString <- function(json_string) {
    # Load a `CppJson` object from string
    bart_json <- createCppJsonString(json_string)