Skip to content

Commit 225e28f

Browse files
authored
Merge pull request #103 from StochasticTree/data_update_cleanup
Cleaning up R / Python data interface
2 parents 82fe588 + 7f91954 commit 225e28f

File tree

16 files changed

+148
-117
lines changed

16 files changed

+148
-117
lines changed

R/bcf.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
742742
}
743743

744744
# Update leaf predictions and residual
745-
forest_samples_tau$update_residual(forest_dataset_train, outcome_train, forest_model_tau, i-1)
745+
forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1)
746746
}
747747

748748
# Sample variance parameters (if requested)
@@ -848,7 +848,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
848848
}
849849

850850
# Update leaf predictions and residual
851-
forest_samples_tau$update_residual(forest_dataset_train, outcome_train, forest_model_tau, i-1)
851+
forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1)
852852
}
853853

854854
# Sample variance parameters (if requested)

R/cpp11.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ adjust_residual_forest_container_cpp <- function(data, residual, forest_samples,
332332
invisible(.Call(`_stochtree_adjust_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
333333
}
334334

335-
update_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, forest_num) {
336-
invisible(.Call(`_stochtree_update_residual_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num))
335+
propagate_basis_update_forest_container_cpp <- function(data, residual, forest_samples, tracker, forest_num) {
336+
invisible(.Call(`_stochtree_propagate_basis_update_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num))
337337
}
338338

339339
predict_forest_cpp <- function(forest_samples, dataset) {

R/data.R

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,6 @@ Outcome <- R6::R6Class(
157157
}
158158
}
159159
overwrite_column_vector_cpp(self$data_ptr, new_vector)
160-
},
161-
162-
#' @description
163-
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
164-
#' This function is run after the `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
165-
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
166-
#' @return NULL
167-
propagate_trees_new_outcome = function(forest_model) {
168-
propagate_trees_column_vector_cpp(forest_model$tracker_ptr, self$data_ptr)
169160
}
170161
)
171162
)

R/forest.R

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -177,31 +177,6 @@ ForestSamples <- R6::R6Class(
177177
)
178178
},
179179

180-
#' @description
181-
#' Updates the residual used for training tree ensembles by iteratively
182-
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
183-
#' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
184-
#'
185-
#' This is useful in cases where a basis (for e.g. leaf regression) is updated outside
186-
#' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF).
187-
#' Once a basis has been updated, the overall "function" represented by a tree model has
188-
#' changed and this should be reflected through to the residual before the next sampling loop is run.
189-
#' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
190-
#' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
191-
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
192-
#' @param forest_num Index of forest used to update residuals (starting at 1, in R style)
193-
update_residual = function(dataset, outcome, forest_model, forest_num) {
194-
stopifnot(!is.null(dataset$data_ptr))
195-
stopifnot(!is.null(outcome$data_ptr))
196-
stopifnot(!is.null(forest_model$tracker_ptr))
197-
stopifnot(!is.null(self$forest_container_ptr))
198-
199-
update_residual_forest_container_cpp(
200-
dataset$data_ptr, outcome$data_ptr, self$forest_container_ptr,
201-
forest_model$tracker_ptr, forest_num
202-
)
203-
},
204-
205180
#' @description
206181
#' Store the trees and metadata of `ForestDataset` class in a json file
207182
#' @param json_filename Name of output json file (must end in ".json")

R/model.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,40 @@ ForestModel <- R6::R6Class(
9393
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
9494
)
9595
}
96+
},
97+
98+
#' @description
99+
#' Propagates basis update through to the (full/partial) residual by iteratively
100+
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
101+
#' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
102+
#'
103+
#' This is useful in cases where a basis (for e.g. leaf regression) is updated outside
104+
#' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF).
105+
#' Once a basis has been updated, the overall "function" represented by a tree model has
106+
#' changed and this should be reflected through to the residual before the next sampling loop is run.
107+
#' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
108+
#' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
109+
#' @param forest_samples `ForestSamples` object storing draws of tree ensembles
110+
#' @param forest_num Index of forest used to update residuals (starting at 1, in R style)
111+
propagate_basis_update = function(dataset, outcome, forest_samples, forest_num) {
112+
stopifnot(!is.null(dataset$data_ptr))
113+
stopifnot(!is.null(outcome$data_ptr))
114+
stopifnot(!is.null(self$tracker_ptr))
115+
stopifnot(!is.null(forest_samples$forest_container_ptr))
116+
117+
propagate_basis_update_forest_container_cpp(
118+
dataset$data_ptr, outcome$data_ptr, forest_samples$forest_container_ptr,
119+
self$tracker_ptr, forest_num
120+
)
121+
},
122+
123+
#' @description
124+
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
125+
#' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
126+
#' @param residual Outcome used to sample the forest
127+
#' @return NULL
128+
propagate_residual_update = function(residual) {
129+
propagate_trees_column_vector_cpp(self$tracker_ptr, residual$data_ptr)
96130
}
97131
)
98132
)

man/ForestModel.Rd

Lines changed: 58 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ForestSamples.Rd

Lines changed: 0 additions & 31 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/Outcome.Rd

Lines changed: 0 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cpp11.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,10 @@ extern "C" SEXP _stochtree_adjust_residual_forest_container_cpp(SEXP data, SEXP
620620
END_CPP11
621621
}
622622
// forest.cpp
623-
void update_residual_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, int forest_num);
624-
extern "C" SEXP _stochtree_update_residual_forest_container_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP forest_num) {
623+
void propagate_basis_update_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, int forest_num);
624+
extern "C" SEXP _stochtree_propagate_basis_update_forest_container_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP forest_num) {
625625
BEGIN_CPP11
626-
update_residual_forest_container_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestContainer>>>(forest_samples), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<int>>(forest_num));
626+
propagate_basis_update_forest_container_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestContainer>>>(forest_samples), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<int>>(forest_num));
627627
return R_NilValue;
628628
END_CPP11
629629
}
@@ -1034,6 +1034,7 @@ static const R_CallMethodDef CallEntries[] = {
10341034
{"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2},
10351035
{"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2},
10361036
{"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3},
1037+
{"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5},
10371038
{"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2},
10381039
{"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3},
10391040
{"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3},
@@ -1081,7 +1082,6 @@ static const R_CallMethodDef CallEntries[] = {
10811082
{"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2},
10821083
{"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
10831084
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
1084-
{"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5},
10851085
{NULL, NULL, 0}
10861086
};
10871087
}

src/forest.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,11 @@ void adjust_residual_forest_container_cpp(cpp11::external_pointer<StochTree::For
356356
}
357357

358358
[[cpp11::register]]
359-
void update_residual_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
360-
cpp11::external_pointer<StochTree::ColumnVector> residual,
361-
cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
362-
cpp11::external_pointer<StochTree::ForestTracker> tracker,
363-
int forest_num) {
359+
void propagate_basis_update_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
360+
cpp11::external_pointer<StochTree::ColumnVector> residual,
361+
cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
362+
cpp11::external_pointer<StochTree::ForestTracker> tracker,
363+
int forest_num) {
364364
// Perform the update (addition / subtraction) operation
365365
StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, forest_samples->GetEnsemble(forest_num));
366366
}

0 commit comments

Comments
 (0)