Skip to content

Commit f5b2c72

Browse files
committed
Merge branch 'main' into python-update-0.1.1
2 parents f454c1f + 5bbac93 commit f5b2c72

File tree

4 files changed

+46
-42
lines changed

4 files changed

+46
-42
lines changed

R/bart.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,13 +826,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
826826
if ((!keep_gfr) && (num_gfr > 0)) {
827827
for (i in 1:num_gfr) {
828828
if (include_mean_forest) {
829-
forest_samples_mean$delete_sample(i-1)
829+
forest_samples_mean$delete_sample(0)
830830
}
831831
if (include_variance_forest) {
832-
forest_samples_variance$delete_sample(i-1)
832+
forest_samples_variance$delete_sample(0)
833833
}
834834
if (has_rfx) {
835-
rfx_samples$delete_sample(i-1)
835+
rfx_samples$delete_sample(0)
836836
}
837837
}
838838
if (sample_sigma_global) {

R/bcf.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,13 +1214,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12141214
# Remove GFR samples if they are not to be retained
12151215
if ((!keep_gfr) && (num_gfr > 0)) {
12161216
for (i in 1:num_gfr) {
1217-
forest_samples_mu$delete_sample(i-1)
1218-
forest_samples_tau$delete_sample(i-1)
1217+
forest_samples_mu$delete_sample(0)
1218+
forest_samples_tau$delete_sample(0)
12191219
if (include_variance_forest) {
1220-
forest_samples_variance$delete_sample(i-1)
1220+
forest_samples_variance$delete_sample(0)
12211221
}
12221222
if (has_rfx) {
1223-
rfx_samples$delete_sample(i-1)
1223+
rfx_samples$delete_sample(0)
12241224
}
12251225
}
12261226
if (sample_sigma_global) {

stochtree/bart.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -244,20 +244,22 @@ def sample(
244244
drop_vars_mean = mean_forest_params_updated["drop_vars"]
245245

246246
# 3. Variance forest parameters
247-
num_trees_variance = variance_forest_params_updated["num_trees"]
248-
alpha_variance = variance_forest_params_updated["alpha"]
249-
beta_variance = variance_forest_params_updated["beta"]
250-
min_samples_leaf_variance = variance_forest_params_updated["min_samples_leaf"]
251-
max_depth_variance = variance_forest_params_updated["max_depth"]
252-
a_0 = variance_forest_params_updated["leaf_prior_calibration_param"]
253-
variance_forest_leaf_init = variance_forest_params_updated[
254-
"var_forest_leaf_init"
255-
]
256-
a_forest = variance_forest_params_updated["var_forest_prior_shape"]
257-
b_forest = variance_forest_params_updated["var_forest_prior_scale"]
258-
keep_vars_variance = variance_forest_params_updated["keep_vars"]
259-
drop_vars_variance = variance_forest_params_updated["drop_vars"]
260-
247+
num_trees_variance = variance_forest_params_updated['num_trees']
248+
alpha_variance = variance_forest_params_updated['alpha']
249+
beta_variance = variance_forest_params_updated['beta']
250+
min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf']
251+
max_depth_variance = variance_forest_params_updated['max_depth']
252+
a_0 = variance_forest_params_updated['leaf_prior_calibration_param']
253+
variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init']
254+
a_forest = variance_forest_params_updated['var_forest_prior_shape']
255+
b_forest = variance_forest_params_updated['var_forest_prior_scale']
256+
keep_vars_variance = variance_forest_params_updated['keep_vars']
257+
drop_vars_variance = variance_forest_params_updated['drop_vars']
258+
259+
# Override keep_gfr if there are no MCMC samples
260+
if num_mcmc == 0:
261+
keep_gfr = True
262+
261263
# Check that num_chains >= 1
262264
if not isinstance(num_chains, Integral) or num_chains < 1:
263265
raise ValueError("num_chains must be an integer greater than 0")
@@ -1091,11 +1093,11 @@ def sample(
10911093
if not keep_gfr and num_gfr > 0:
10921094
for i in range(num_gfr):
10931095
if self.include_mean_forest:
1094-
self.forest_container_mean.delete_sample(i)
1096+
self.forest_container_mean.delete_sample(0)
10951097
if self.include_variance_forest:
1096-
self.forest_container_variance.delete_sample(i)
1098+
self.forest_container_variance.delete_sample(0)
10971099
if self.has_rfx:
1098-
self.rfx_container.delete_sample(i)
1100+
self.rfx_container.delete_sample(0)
10991101
if self.sample_sigma_global:
11001102
self.global_var_samples = self.global_var_samples[num_gfr:]
11011103
if self.sample_sigma_leaf:

stochtree/bcf.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -319,20 +319,22 @@ def sample(
319319
drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"]
320320

321321
# 4. Variance forest parameters
322-
num_trees_variance = variance_forest_params_updated["num_trees"]
323-
alpha_variance = variance_forest_params_updated["alpha"]
324-
beta_variance = variance_forest_params_updated["beta"]
325-
min_samples_leaf_variance = variance_forest_params_updated["min_samples_leaf"]
326-
max_depth_variance = variance_forest_params_updated["max_depth"]
327-
a_0 = variance_forest_params_updated["leaf_prior_calibration_param"]
328-
variance_forest_leaf_init = variance_forest_params_updated[
329-
"var_forest_leaf_init"
330-
]
331-
a_forest = variance_forest_params_updated["var_forest_prior_shape"]
332-
b_forest = variance_forest_params_updated["var_forest_prior_scale"]
333-
keep_vars_variance = variance_forest_params_updated["keep_vars"]
334-
drop_vars_variance = variance_forest_params_updated["drop_vars"]
335-
322+
num_trees_variance = variance_forest_params_updated['num_trees']
323+
alpha_variance = variance_forest_params_updated['alpha']
324+
beta_variance = variance_forest_params_updated['beta']
325+
min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf']
326+
max_depth_variance = variance_forest_params_updated['max_depth']
327+
a_0 = variance_forest_params_updated['leaf_prior_calibration_param']
328+
variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init']
329+
a_forest = variance_forest_params_updated['var_forest_prior_shape']
330+
b_forest = variance_forest_params_updated['var_forest_prior_scale']
331+
keep_vars_variance = variance_forest_params_updated['keep_vars']
332+
drop_vars_variance = variance_forest_params_updated['drop_vars']
333+
334+
# Override keep_gfr if there are no MCMC samples
335+
if num_mcmc == 0:
336+
keep_gfr = True
337+
336338
# Variable weight preprocessing (and initialization if necessary)
337339
if variable_weights is None:
338340
if X_train.ndim > 1:
@@ -1772,12 +1774,12 @@ def sample(
17721774
# Remove GFR samples if they are not to be retained
17731775
if not keep_gfr and num_gfr > 0:
17741776
for i in range(num_gfr):
1775-
self.forest_container_mu.delete_sample(i)
1776-
self.forest_container_tau.delete_sample(i)
1777+
self.forest_container_mu.delete_sample(0)
1778+
self.forest_container_tau.delete_sample(0)
17771779
if self.include_variance_forest:
1778-
self.forest_container_variance.delete_sample(i)
1780+
self.forest_container_variance.delete_sample(0)
17791781
if self.has_rfx:
1780-
self.rfx_container.delete_sample(i)
1782+
self.rfx_container.delete_sample(0)
17811783
if self.adaptive_coding:
17821784
self.b1_samples = self.b1_samples[num_gfr:]
17831785
self.b0_samples = self.b0_samples[num_gfr:]

0 commit comments

Comments
 (0)