Skip to content

Commit 82fe588

Browse files
authored
Merge pull request #102 from StochasticTree/update_outcome
Update to allow overwriting the outcome in the R prototype interface
2 parents 8bffc8c + 84719bd commit 82fe588

File tree

8 files changed

+139
-2
lines changed

8 files changed

+139
-2
lines changed

R/cpp11.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ subtract_from_column_vector_cpp <- function(outcome, update_vector) {
5252
invisible(.Call(`_stochtree_subtract_from_column_vector_cpp`, outcome, update_vector))
5353
}
5454

55+
# Low-level binding: passes `outcome` and `new_vector` to the registered native
# routine `_stochtree_overwrite_column_vector_cpp` and returns its value invisibly.
overwrite_column_vector_cpp <- function(outcome, new_vector) {
  native_result <- .Call(`_stochtree_overwrite_column_vector_cpp`, outcome, new_vector)
  invisible(native_result)
}
58+
59+
# Low-level binding: passes `tracker` and `residual` to the registered native
# routine `_stochtree_propagate_trees_column_vector_cpp` and returns its value invisibly.
propagate_trees_column_vector_cpp <- function(tracker, residual) {
  native_result <- .Call(`_stochtree_propagate_trees_column_vector_cpp`, tracker, residual)
  invisible(native_result)
}
62+
5563
# Low-level binding: passes `vector_ptr` to the registered native routine
# `_stochtree_get_residual_cpp` and returns whatever that routine returns.
get_residual_cpp <- function(vector_ptr) {
  native_result <- .Call(`_stochtree_get_residual_cpp`, vector_ptr)
  native_result
}

R/data.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,32 @@ Outcome <- R6::R6Class(
140140
}
141141
}
142142
subtract_from_column_vector_cpp(self$data_ptr, update_vector)
143+
},
144+
145+
#' @description
#' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the corresponding element of `new_vector`
#' @param new_vector Numeric vector, or 2d matrix (flattened with `as.numeric` before use), whose values overwrite the current data
#' @return NULL
update_data = function(new_vector) {
    # Reject non-numeric input up front; the message names the actual argument
    if (!is.numeric(new_vector)) {
        stop("new_vector must be a numeric vector or 2d matrix")
    }
    # Anything carrying a `dim` attribute must be at most 2d and is flattened
    # to a plain numeric vector before being handed to the C++ layer
    dim_vec <- dim(new_vector)
    if (!is.null(dim_vec)) {
        if (length(dim_vec) > 2) stop("if new_vector is provided as a matrix, it must be 2d")
        new_vector <- as.numeric(new_vector)
    }
    overwrite_column_vector_cpp(self$data_ptr, new_vector)
},
161+
162+
#' @description
#' Subtract the current predictions of each tree from the outcome (i.e. partial residual) data.
#' Intended to be called after the `update_data` method, which replaces the partial residual with an entirely new stream of outcome data.
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
#' @return NULL
propagate_trees_new_outcome = function(forest_model) {
    tracker <- forest_model$tracker_ptr
    residual <- self$data_ptr
    propagate_trees_column_vector_cpp(tracker, residual)
}
144170
)
145171
)

include/stochtree/data.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class ColumnVector {
126126
void LoadData(double* data_ptr, data_size_t num_row);
127127
void AddToData(double* data_ptr, data_size_t num_row);
128128
void SubtractFromData(double* data_ptr, data_size_t num_row);
129+
void OverwriteData(double* data_ptr, data_size_t num_row);
129130
inline data_size_t NumRows() {return data_.size();}
130131
inline Eigen::VectorXd& GetData() {return data_;}
131132
private:

include/stochtree/tree_sampler.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,6 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas
196196
tracker.SyncPredictions();
197197
}
198198

199-
200-
201199
static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
202200
bool requires_basis, std::function<double(double, double)> op) {
203201
data_size_t n = dataset.GetCovariates().rows();
@@ -225,6 +223,20 @@ static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestData
225223
tracker.SyncPredictions();
226224
}
227225

226+
/*!
 * \brief Subtract each observation's currently tracked prediction from the residual, in place.
 *
 * For every row `i` of `residual`, the stored value is replaced with
 * `residual.GetElement(i) - tracker.GetSamplePrediction(i)`. The operation is
 * always a subtraction; no operator is configurable here.
 *
 * \param tracker Tracking object queried for the per-observation prediction.
 * \param residual Outcome / partial residual vector, modified in place.
 */
static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
  const data_size_t n = residual.NumRows();
  for (data_size_t i = 0; i < n; i++) {
    // NOTE(review): assumes the tracker covers at least n observations -- confirm against callers
    const double outcome_value = residual.GetElement(i);
    const double pred_value = tracker.GetSamplePrediction(i);
    residual.SetElement(i, outcome_value - pred_value);
  }
}
239+
228240
static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
229241
bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
230242
data_size_t n = dataset.GetCovariates().rows();

man/Outcome.Rd

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/R_data.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <cpp11.hpp>
22
#include <stochtree/container.h>
33
#include <stochtree/data.h>
4+
#include <stochtree/partition_tracker.h>
5+
#include <stochtree/tree_sampler.h>
46
#include <memory>
57
#include <vector>
68

@@ -136,6 +138,25 @@ void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVe
136138
UNPROTECT(1);
137139
}
138140

141+
[[cpp11::register]]
142+
void overwrite_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles new_vector) {
143+
// Unpack pointers to data and dimensions
144+
StochTree::data_size_t n = new_vector.size();
145+
double* update_data_ptr = REAL(PROTECT(new_vector));
146+
147+
// Add to the outcome data using the C++ API
148+
outcome->OverwriteData(update_data_ptr, n);
149+
150+
// Unprotect pointers to R data
151+
UNPROTECT(1);
152+
}
153+
154+
[[cpp11::register]]
155+
void propagate_trees_column_vector_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker,
156+
cpp11::external_pointer<StochTree::ColumnVector> residual) {
157+
StochTree::UpdateResidualNewOutcome(*tracker, *residual);
158+
}
159+
139160
[[cpp11::register]]
140161
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr) {
141162
// Initialize output vector

src/cpp11.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,22 @@ extern "C" SEXP _stochtree_subtract_from_column_vector_cpp(SEXP outcome, SEXP up
103103
END_CPP11
104104
}
105105
// R_data.cpp
106+
// Forward declaration of the implementation (see the R_data.cpp marker above).
void overwrite_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles new_vector);
107+
// C-linkage entry point for R's .Call interface: converts both SEXP arguments with
// cpp11::as_cpp, invokes the implementation, and returns R_NilValue.
// NOTE(review): BEGIN_CPP11 / END_CPP11 presumably translate C++ exceptions into R errors -- confirm in cpp11 docs.
extern "C" SEXP _stochtree_overwrite_column_vector_cpp(SEXP outcome, SEXP new_vector) {
108+
BEGIN_CPP11
109+
overwrite_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(new_vector));
110+
return R_NilValue;
111+
END_CPP11
112+
}
113+
// R_data.cpp
114+
// Forward declaration of the implementation (see the R_data.cpp marker above).
void propagate_trees_column_vector_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker, cpp11::external_pointer<StochTree::ColumnVector> residual);
115+
// C-linkage entry point for R's .Call interface: converts both SEXP arguments to
// external pointers with cpp11::as_cpp, invokes the implementation, and returns R_NilValue.
// NOTE(review): BEGIN_CPP11 / END_CPP11 presumably translate C++ exceptions into R errors -- confirm in cpp11 docs.
extern "C" SEXP _stochtree_propagate_trees_column_vector_cpp(SEXP tracker, SEXP residual) {
116+
BEGIN_CPP11
117+
propagate_trees_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(residual));
118+
return R_NilValue;
119+
END_CPP11
120+
}
121+
// R_data.cpp
106122
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr);
107123
extern "C" SEXP _stochtree_get_residual_cpp(SEXP vector_ptr) {
108124
BEGIN_CPP11
@@ -1014,9 +1030,11 @@ static const R_CallMethodDef CallEntries[] = {
10141030
{"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1},
10151031
{"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1},
10161032
{"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1},
1033+
{"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2},
10171034
{"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2},
10181035
{"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2},
10191036
{"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3},
1037+
{"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2},
10201038
{"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3},
10211039
{"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3},
10221040
{"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2},

src/data.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ void ColumnVector::SubtractFromData(double* data_ptr, data_size_t num_row) {
110110
UpdateData(data_ptr, num_row, std::minus<double>());
111111
}
112112

113+
void ColumnVector::OverwriteData(double* data_ptr, data_size_t num_row) {
114+
double ptr_val;
115+
for (data_size_t i = 0; i < num_row; ++i) {
116+
ptr_val = static_cast<double>(*(data_ptr + i));
117+
data_(i) = ptr_val;
118+
}
119+
}
120+
113121
void ColumnVector::UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op) {
114122
double ptr_val;
115123
double updated_val;

0 commit comments

Comments
 (0)