Skip to content

Commit 06c803a

Browse files
committed
Override keep_gfr if no MCMC samples
1 parent e091740 commit 06c803a

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

stochtree/bart.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ def sample(self, X_train: Union[np.array, pd.DataFrame], y_train: np.array, basi
226226
keep_vars_variance = variance_forest_params_updated['keep_vars']
227227
drop_vars_variance = variance_forest_params_updated['drop_vars']
228228

229+
# Override keep_gfr if there are no MCMC samples
230+
if num_mcmc == 0:
231+
keep_gfr = True
232+
229233
# Check that num_chains >= 1
230234
if not isinstance(num_chains, Integral) or num_chains < 1:
231235
raise ValueError("num_chains must be an integer greater than 0")

stochtree/bcf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
299299
keep_vars_variance = variance_forest_params_updated['keep_vars']
300300
drop_vars_variance = variance_forest_params_updated['drop_vars']
301301

302+
# Override keep_gfr if there are no MCMC samples
303+
if num_mcmc == 0:
304+
keep_gfr = True
305+
302306
# Variable weight preprocessing (and initialization if necessary)
303307
if variable_weights is None:
304308
if X_train.ndim > 1:

0 commit comments

Comments
 (0)