Skip to content

Commit f2bfe21

Browse files
committed
Fixed minor BCF bug
1 parent 1ccdd1b commit f2bfe21

File tree

3 files changed

+163
-16
lines changed

3 files changed

+163
-16
lines changed

R/bcf.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
286286
previous_forest_samples_variance <- previous_bcf_model$forests_variance
287287
} else previous_forest_samples_variance <- NULL
288288
if (previous_bcf_model$model_params$sample_sigma_global) {
289-
previous_global_var_samples <- previous_bcf_model$sigma2_samples*(
290-
previous_var_scale / (previous_y_scale*previous_y_scale)
289+
previous_global_var_samples <- previous_bcf_model$sigma2_samples / (
290+
previous_y_scale*previous_y_scale
291291
)
292292
} else previous_global_var_samples <- NULL
293293
if (previous_bcf_model$model_params$sample_sigma_leaf_mu) {
@@ -313,7 +313,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
313313
} else {
314314
previous_y_bar <- NULL
315315
previous_y_scale <- NULL
316-
previous_var_scale <- NULL
317316
previous_global_var_samples <- NULL
318317
previous_leaf_var_mu_samples <- NULL
319318
previous_leaf_var_tau_samples <- NULL

test/R/testthat/test-bcf.R

Lines changed: 158 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ test_that("MCMC BCF", {
8181
)
8282
})
8383

84-
test_that("GFR BART", {
84+
test_that("GFR BCF", {
8585
skip_on_cran()
8686

8787
# Generate simulated data
@@ -90,21 +90,21 @@ test_that("GFR BART", {
9090
X <- matrix(runif(n*p), ncol = p)
9191
mu_X <- (
9292
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
93-
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
94-
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
95-
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
93+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
94+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
95+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
9696
)
9797
pi_X <- (
9898
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
99-
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
100-
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
101-
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
99+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
100+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
101+
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
102102
)
103103
tau_X <- (
104104
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
105-
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
106-
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
107-
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
105+
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
106+
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
107+
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
108108
)
109109
Z <- rbinom(n, 1, pi_X)
110110
noise_sd <- 1
@@ -181,3 +181,151 @@ test_that("GFR BART", {
181181
num_mcmc = 10, general_params = general_param_list)
182182
)
183183
})
184+
185+
test_that("Warmstart BCF", {
186+
skip_on_cran()
187+
188+
# Generate simulated data
189+
n <- 100
190+
p <- 5
191+
X <- matrix(runif(n*p), ncol = p)
192+
mu_X <- (
193+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
194+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
195+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
196+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
197+
)
198+
pi_X <- (
199+
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
200+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
201+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
202+
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
203+
)
204+
tau_X <- (
205+
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
206+
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
207+
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
208+
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
209+
)
210+
Z <- rbinom(n, 1, pi_X)
211+
noise_sd <- 1
212+
y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
213+
test_set_pct <- 0.2
214+
n_test <- round(test_set_pct*n)
215+
n_train <- n - n_test
216+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
217+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
218+
X_test <- X[test_inds,]
219+
X_train <- X[train_inds,]
220+
Z_test <- Z[test_inds]
221+
Z_train <- Z[train_inds]
222+
pi_test <- pi[test_inds]
223+
pi_train <- pi[train_inds]
224+
mu_test <- mu_X[test_inds]
225+
mu_train <- mu_X[train_inds]
226+
tau_test <- tau_X[test_inds]
227+
tau_train <- tau_X[train_inds]
228+
y_test <- y[test_inds]
229+
y_train <- y[train_inds]
230+
231+
# Run a BCF model with only GFR
232+
general_param_list <- list(num_chains = 1, keep_every = 1)
233+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
234+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
235+
propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
236+
num_mcmc = 0, general_params = general_param_list)
237+
238+
# Save to JSON string
239+
bcf_model_json_string <- saveBCFModelToJsonString(bcf_model)
240+
241+
# Run a new BCF chain from the existing (X)BCF model
242+
general_param_list <- list(num_chains = 3, keep_every = 5)
243+
expect_no_error(
244+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
245+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
246+
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
247+
num_mcmc = 10, previous_model_json = bcf_model_json_string,
248+
previous_model_warmstart_sample_num = 1,
249+
general_params = general_param_list)
250+
)
251+
252+
# Generate simulated data with random effects
253+
n <- 100
254+
p <- 5
255+
X <- matrix(runif(n*p), ncol = p)
256+
mu_X <- (
257+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
258+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
259+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
260+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
261+
)
262+
pi_X <- (
263+
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
264+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
265+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
266+
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
267+
)
268+
tau_X <- (
269+
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
270+
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
271+
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
272+
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
273+
)
274+
Z <- rbinom(n, 1, pi_X)
275+
rfx_group_ids <- sample(1:2, size = n, replace = T)
276+
rfx_basis <- rep(1, n)
277+
rfx_coefs <- c(-5, 5)
278+
rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis
279+
noise_sd <- 1
280+
y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd)
281+
test_set_pct <- 0.2
282+
n_test <- round(test_set_pct*n)
283+
n_train <- n - n_test
284+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
285+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
286+
X_test <- X[test_inds,]
287+
X_train <- X[train_inds,]
288+
Z_test <- Z[test_inds]
289+
Z_train <- Z[train_inds]
290+
pi_test <- pi[test_inds]
291+
pi_train <- pi[train_inds]
292+
mu_test <- mu_X[test_inds]
293+
mu_train <- mu_X[train_inds]
294+
tau_test <- tau_X[test_inds]
295+
tau_train <- tau_X[train_inds]
296+
rfx_group_ids_test <- rfx_group_ids[test_inds]
297+
rfx_group_ids_train <- rfx_group_ids[train_inds]
298+
rfx_basis_test <- rfx_basis[test_inds]
299+
rfx_basis_train <- rfx_basis[train_inds]
300+
y_test <- y[test_inds]
301+
y_train <- y[train_inds]
302+
303+
# Run a BCF model with only GFR
304+
general_param_list <- list(num_chains = 1, keep_every = 1)
305+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
306+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
307+
rfx_group_ids_train = rfx_group_ids_train,
308+
rfx_group_ids_test = rfx_group_ids_test,
309+
rfx_basis_train = rfx_basis_train,
310+
rfx_basis_test = rfx_basis_test,
311+
propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
312+
num_mcmc = 0, general_params = general_param_list)
313+
314+
# Save to JSON string
315+
bcf_model_json_string <- saveBCFModelToJsonString(bcf_model)
316+
317+
# Run a new BCF chain from the existing (X)BCF model
318+
general_param_list <- list(num_chains = 3, keep_every = 5)
319+
expect_no_error(
320+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
321+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
322+
rfx_group_ids_train = rfx_group_ids_train,
323+
rfx_group_ids_test = rfx_group_ids_test,
324+
rfx_basis_train = rfx_basis_train,
325+
rfx_basis_test = rfx_basis_test,
326+
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
327+
num_mcmc = 10, previous_model_json = bcf_model_json_string,
328+
previous_model_warmstart_sample_num = 1,
329+
general_params = general_param_list)
330+
)
331+
})

test/R/testthat/test-serialization.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ test_that("BART Serialization", {
77
X <- matrix(runif(n*p), ncol = p)
88
f_XW <- (
99
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
10-
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
11-
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
12-
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
10+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
11+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
12+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
1313
)
1414
noise_sd <- 1
1515
y <- f_XW + rnorm(n, 0, noise_sd)

0 commit comments

Comments
 (0)