Skip to content

Cleaning up R / Python data interface #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}

# Update leaf predictions and residual
forest_samples_tau$update_residual(forest_dataset_train, outcome_train, forest_model_tau, i-1)
forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1)
}

# Sample variance parameters (if requested)
Expand Down Expand Up @@ -848,7 +848,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}

# Update leaf predictions and residual
forest_samples_tau$update_residual(forest_dataset_train, outcome_train, forest_model_tau, i-1)
forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1)
}

# Sample variance parameters (if requested)
Expand Down
4 changes: 2 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ adjust_residual_forest_container_cpp <- function(data, residual, forest_samples,
invisible(.Call(`_stochtree_adjust_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
}

update_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, forest_num) {
invisible(.Call(`_stochtree_update_residual_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num))
# Thin R-side binding for the native routine
# `_stochtree_propagate_basis_update_forest_container_cpp`: forwards the five
# arguments to `.Call()` in the order given and returns the result invisibly.
# NOTE(review): this file looks like cpp11-generated glue -- confirm before
# hand-editing, since regeneration would overwrite manual changes.
propagate_basis_update_forest_container_cpp <- function(data, residual, forest_samples, tracker, forest_num) {
invisible(.Call(`_stochtree_propagate_basis_update_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num))
}

predict_forest_cpp <- function(forest_samples, dataset) {
Expand Down
9 changes: 0 additions & 9 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,6 @@ Outcome <- R6::R6Class(
}
}
overwrite_column_vector_cpp(self$data_ptr, new_vector)
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
#' This function is run after the `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
#' @return NULL
propagate_trees_new_outcome = function(forest_model) {
    # Collect the two native handles first: the tracker owned by the forest
    # model and the residual vector owned by this object.
    tracker_handle <- forest_model$tracker_ptr
    residual_handle <- self$data_ptr

    # Hand both to the native routine; its result is returned as-is.
    propagate_trees_column_vector_cpp(tracker_handle, residual_handle)
}
)
)
Expand Down
25 changes: 0 additions & 25 deletions R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,31 +177,6 @@ ForestSamples <- R6::R6Class(
)
},

#' @description
#' Updates the residual used for training tree ensembles by iteratively
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
#' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
#'
#' This is useful in cases where a basis (for e.g. leaf regression) is updated outside
#' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF).
#' Once a basis has been updated, the overall "function" represented by a tree model has
#' changed and this should be reflected through to the residual before the next sampling loop is run.
#' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
#' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
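#' @param forest_num Index of the forest (within the container) used to update residuals. The value is forwarded unchanged to the C++ container lookup, so it is zero-based rather than R-style one-based (callers in `bcf()` pass `i-1`).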
update_residual = function(dataset, outcome, forest_model, forest_num) {
    # All four native handles must be present before dispatching to C++;
    # the checks run in order and stop at the first missing handle.
    stopifnot(
        !is.null(dataset$data_ptr),
        !is.null(outcome$data_ptr),
        !is.null(forest_model$tracker_ptr),
        !is.null(self$forest_container_ptr)
    )

    covariate_handle <- dataset$data_ptr
    residual_handle <- outcome$data_ptr
    container_handle <- self$forest_container_ptr
    tracker_handle <- forest_model$tracker_ptr

    # `forest_num` is handed through untouched; no index shift happens here.
    update_residual_forest_container_cpp(
        covariate_handle, residual_handle, container_handle,
        tracker_handle, forest_num
    )
},

#' @description
#' Store the trees and metadata of `ForestDataset` class in a json file
#' @param json_filename Name of output json file (must end in ".json")
Expand Down
34 changes: 34 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,40 @@ ForestModel <- R6::R6Class(
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
)
}
},

#' @description
#' Propagates basis update through to the (full/partial) residual by iteratively
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
#' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
#'
#' This is useful in cases where a basis (for e.g. leaf regression) is updated outside
#' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF).
#' Once a basis has been updated, the overall "function" represented by a tree model has
#' changed and this should be reflected through to the residual before the next sampling loop is run.
#' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
#' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
#' @param forest_samples `ForestSamples` object storing draws of tree ensembles
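#' @param forest_num Index of the forest (within `forest_samples`) used to update residuals. The value is forwarded unchanged to the C++ container lookup, so it is zero-based rather than R-style one-based (callers in `bcf()` pass `i-1`).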
propagate_basis_update = function(dataset, outcome, forest_samples, forest_num) {
    # All four native handles must be present before dispatching to C++;
    # the checks run in order and stop at the first missing handle.
    stopifnot(
        !is.null(dataset$data_ptr),
        !is.null(outcome$data_ptr),
        !is.null(self$tracker_ptr),
        !is.null(forest_samples$forest_container_ptr)
    )

    covariate_handle <- dataset$data_ptr
    residual_handle <- outcome$data_ptr
    container_handle <- forest_samples$forest_container_ptr
    tracker_handle <- self$tracker_ptr

    # `forest_num` is handed through untouched; no index shift happens here.
    propagate_basis_update_forest_container_cpp(
        covariate_handle, residual_handle, container_handle,
        tracker_handle, forest_num
    )
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
#' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
#' @param residual Outcome used to sample the forest
#' @return NULL
propagate_residual_update = function(residual) {
    # Fail early with a readable R-level error instead of handing a NULL
    # handle to the native layer; this mirrors the guards used by
    # `propagate_basis_update` in this class.
    stopifnot(!is.null(self$tracker_ptr))
    stopifnot(!is.null(residual$data_ptr))

    # Pass this model's tracker and the residual's data handle to the native
    # routine; whatever that wrapper returns is returned unchanged.
    propagate_trees_column_vector_cpp(self$tracker_ptr, residual$data_ptr)
}
)
)
Expand Down
58 changes: 58 additions & 0 deletions man/ForestModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 0 additions & 31 deletions man/ForestSamples.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 0 additions & 22 deletions man/Outcome.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,10 @@ extern "C" SEXP _stochtree_adjust_residual_forest_container_cpp(SEXP data, SEXP
END_CPP11
}
// forest.cpp
void update_residual_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, int forest_num);
extern "C" SEXP _stochtree_update_residual_forest_container_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP forest_num) {
void propagate_basis_update_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, int forest_num);
extern "C" SEXP _stochtree_propagate_basis_update_forest_container_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP forest_num) {
BEGIN_CPP11
update_residual_forest_container_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestContainer>>>(forest_samples), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<int>>(forest_num));
propagate_basis_update_forest_container_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestContainer>>>(forest_samples), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<int>>(forest_num));
return R_NilValue;
END_CPP11
}
Expand Down Expand Up @@ -1034,6 +1034,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2},
{"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2},
{"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3},
{"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5},
{"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2},
{"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3},
{"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3},
Expand Down Expand Up @@ -1081,7 +1082,6 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2},
{"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
{"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5},
{NULL, NULL, 0}
};
}
Expand Down
10 changes: 5 additions & 5 deletions src/forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,11 @@ void adjust_residual_forest_container_cpp(cpp11::external_pointer<StochTree::For
}

[[cpp11::register]]
void update_residual_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
cpp11::external_pointer<StochTree::ColumnVector> residual,
cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
cpp11::external_pointer<StochTree::ForestTracker> tracker,
int forest_num) {
void propagate_basis_update_forest_container_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
cpp11::external_pointer<StochTree::ColumnVector> residual,
cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
cpp11::external_pointer<StochTree::ForestTracker> tracker,
int forest_num) {
// R-callable entry point: dereferences the tracker, dataset and residual
// external pointers and passes them to StochTree::UpdateResidualNewBasis
// together with one ensemble fetched from the forest container.
//
// forest_num is given to GetEnsemble() exactly as received -- no offset is
// applied here, so the caller must supply the index in whatever base
// GetEnsemble() expects (presumably zero-based; confirm against ForestContainer).
// Perform the update (addition / subtraction) operation
StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, forest_samples->GetEnsemble(forest_num));
}
Expand Down
Loading
Loading