Skip to content

Update to allow overwriting the outcome in the R prototype interface #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ subtract_from_column_vector_cpp <- function(outcome, update_vector) {
invisible(.Call(`_stochtree_subtract_from_column_vector_cpp`, outcome, update_vector))
}

overwrite_column_vector_cpp <- function(outcome, new_vector) {
invisible(.Call(`_stochtree_overwrite_column_vector_cpp`, outcome, new_vector))
}

propagate_trees_column_vector_cpp <- function(tracker, residual) {
invisible(.Call(`_stochtree_propagate_trees_column_vector_cpp`, tracker, residual))
}

get_residual_cpp <- function(vector_ptr) {
.Call(`_stochtree_get_residual_cpp`, vector_ptr)
}
Expand Down
26 changes: 26 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,32 @@ Outcome <- R6::R6Class(
}
}
subtract_from_column_vector_cpp(self$data_ptr, update_vector)
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector`
#' @param new_vector Vector from which to overwrite the current data
#' @return NULL
update_data = function(new_vector) {
if (!is.numeric(new_vector)) {
stop("update_vector must be a numeric vector or 2d matrix")
} else {
dim_vec <- dim(new_vector)
if (!is.null(dim_vec)) {
if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d")
new_vector <- as.numeric(new_vector)
}
}
overwrite_column_vector_cpp(self$data_ptr, new_vector)
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
#' This function is run after the `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
#' @return NULL
propagate_trees_new_outcome = function(forest_model) {
propagate_trees_column_vector_cpp(forest_model$tracker_ptr, self$data_ptr)
}
)
)
Expand Down
1 change: 1 addition & 0 deletions include/stochtree/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class ColumnVector {
void LoadData(double* data_ptr, data_size_t num_row);
void AddToData(double* data_ptr, data_size_t num_row);
void SubtractFromData(double* data_ptr, data_size_t num_row);
void OverwriteData(double* data_ptr, data_size_t num_row);
inline data_size_t NumRows() {return data_.size();}
inline Eigen::VectorXd& GetData() {return data_;}
private:
Expand Down
16 changes: 14 additions & 2 deletions include/stochtree/tree_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,6 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas
tracker.SyncPredictions();
}



static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
bool requires_basis, std::function<double(double, double)> op) {
data_size_t n = dataset.GetCovariates().rows();
Expand Down Expand Up @@ -225,6 +223,20 @@ static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestData
tracker.SyncPredictions();
}

static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
data_size_t n = residual.NumRows();
double pred_value;
double prev_outcome;
double new_resid;
for (data_size_t i = 0; i < n; i++) {
prev_outcome = residual.GetElement(i);
pred_value = tracker.GetSamplePrediction(i);
// Run op (either plus or minus) on the residual and the new prediction
new_resid = prev_outcome - pred_value;
residual.SetElement(i, new_resid);
}
}

static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
data_size_t n = dataset.GetCovariates().rows();
Expand Down
43 changes: 43 additions & 0 deletions man/Outcome.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions src/R_data.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <cpp11.hpp>
#include <stochtree/container.h>
#include <stochtree/data.h>
#include <stochtree/partition_tracker.h>
#include <stochtree/tree_sampler.h>
#include <memory>
#include <vector>

Expand Down Expand Up @@ -136,6 +138,25 @@ void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVe
UNPROTECT(1);
}

[[cpp11::register]]
void overwrite_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles new_vector) {
// Unpack pointers to data and dimensions
StochTree::data_size_t n = new_vector.size();
double* update_data_ptr = REAL(PROTECT(new_vector));

// Add to the outcome data using the C++ API
outcome->OverwriteData(update_data_ptr, n);

// Unprotect pointers to R data
UNPROTECT(1);
}

[[cpp11::register]]
void propagate_trees_column_vector_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker,
cpp11::external_pointer<StochTree::ColumnVector> residual) {
StochTree::UpdateResidualNewOutcome(*tracker, *residual);
}

[[cpp11::register]]
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr) {
// Initialize output vector
Expand Down
18 changes: 18 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ extern "C" SEXP _stochtree_subtract_from_column_vector_cpp(SEXP outcome, SEXP up
END_CPP11
}
// R_data.cpp
void overwrite_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles new_vector);
extern "C" SEXP _stochtree_overwrite_column_vector_cpp(SEXP outcome, SEXP new_vector) {
BEGIN_CPP11
overwrite_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(new_vector));
return R_NilValue;
END_CPP11
}
// R_data.cpp
void propagate_trees_column_vector_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker, cpp11::external_pointer<StochTree::ColumnVector> residual);
extern "C" SEXP _stochtree_propagate_trees_column_vector_cpp(SEXP tracker, SEXP residual) {
BEGIN_CPP11
propagate_trees_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual));
return R_NilValue;
END_CPP11
}
// R_data.cpp
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr);
extern "C" SEXP _stochtree_get_residual_cpp(SEXP vector_ptr) {
BEGIN_CPP11
Expand Down Expand Up @@ -1014,9 +1030,11 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1},
{"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1},
{"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1},
{"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2},
{"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2},
{"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2},
{"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3},
{"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2},
{"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3},
{"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3},
{"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2},
Expand Down
8 changes: 8 additions & 0 deletions src/data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ void ColumnVector::SubtractFromData(double* data_ptr, data_size_t num_row) {
UpdateData(data_ptr, num_row, std::minus<double>());
}

void ColumnVector::OverwriteData(double* data_ptr, data_size_t num_row) {
double ptr_val;
for (data_size_t i = 0; i < num_row; ++i) {
ptr_val = static_cast<double>(*(data_ptr + i));
data_(i) = ptr_val;
}
}

void ColumnVector::UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op) {
double ptr_val;
double updated_val;
Expand Down
Loading