Skip to content

Commit fc24040

Browse files
committed
Fixed indexing issue in post-sampler cleanup in R and Python
1 parent 06c803a commit fc24040

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

R/bart.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,13 +826,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
826826
if ((!keep_gfr) && (num_gfr > 0)) {
827827
for (i in 1:num_gfr) {
828828
if (include_mean_forest) {
829-
forest_samples_mean$delete_sample(i-1)
829+
forest_samples_mean$delete_sample(0)
830830
}
831831
if (include_variance_forest) {
832-
forest_samples_variance$delete_sample(i-1)
832+
forest_samples_variance$delete_sample(0)
833833
}
834834
if (has_rfx) {
835-
rfx_samples$delete_sample(i-1)
835+
rfx_samples$delete_sample(0)
836836
}
837837
}
838838
if (sample_sigma_global) {

R/bcf.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,13 +1214,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12141214
# Remove GFR samples if they are not to be retained
12151215
if ((!keep_gfr) && (num_gfr > 0)) {
12161216
for (i in 1:num_gfr) {
1217-
forest_samples_mu$delete_sample(i-1)
1218-
forest_samples_tau$delete_sample(i-1)
1217+
forest_samples_mu$delete_sample(0)
1218+
forest_samples_tau$delete_sample(0)
12191219
if (include_variance_forest) {
1220-
forest_samples_variance$delete_sample(i-1)
1220+
forest_samples_variance$delete_sample(0)
12211221
}
12221222
if (has_rfx) {
1223-
rfx_samples$delete_sample(i-1)
1223+
rfx_samples$delete_sample(0)
12241224
}
12251225
}
12261226
if (sample_sigma_global) {

stochtree/bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,9 +684,9 @@ def sample(self, X_train: Union[np.array, pd.DataFrame], y_train: np.array, basi
684684
if not keep_gfr and num_gfr > 0:
685685
for i in range(num_gfr):
686686
if self.include_mean_forest:
687-
self.forest_container_mean.delete_sample(i)
687+
self.forest_container_mean.delete_sample(0)
688688
if self.include_variance_forest:
689-
self.forest_container_variance.delete_sample(i)
689+
self.forest_container_variance.delete_sample(0)
690690
if self.sample_sigma_global:
691691
self.global_var_samples = self.global_var_samples[num_gfr:]
692692
if self.sample_sigma_leaf:

stochtree/bcf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,10 +1055,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
10551055
# Remove GFR samples if they are not to be retained
10561056
if not keep_gfr and num_gfr > 0:
10571057
for i in range(num_gfr):
1058-
self.forest_container_mu.delete_sample(i)
1059-
self.forest_container_tau.delete_sample(i)
1058+
self.forest_container_mu.delete_sample(0)
1059+
self.forest_container_tau.delete_sample(0)
10601060
if self.include_variance_forest:
1061-
self.forest_container_variance.delete_sample(i)
1061+
self.forest_container_variance.delete_sample(0)
10621062
if self.adaptive_coding:
10631063
self.b1_samples = self.b1_samples[num_gfr:]
10641064
self.b0_samples = self.b0_samples[num_gfr:]

0 commit comments

Comments
 (0)