Skip to content

Commit 6695966

Browse files
authored
Merge pull request #106 from StochasticTree/single-tree-predict
Added R method for single tree prediction
2 parents a28af8b + bad7dff commit 6695966

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

R/forest.R

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ ForestSamples <- R6::R6Class(
101101
#' @param forest_dataset `ForestDataset` R class
102102
#' @param forest_num Index of the forest sample within the container
103103
#' @return matrix of predictions with as many rows as in forest_dataset
104-
#' and as many columns as samples in the `ForestContainer`
104+
#' and as many columns as dimensions in the leaves of trees in `ForestContainer`
105105
predict_raw_single_forest = function(forest_dataset, forest_num) {
106106
stopifnot(!is.null(forest_dataset$data_ptr))
107107
# Unpack dimensions
@@ -113,6 +113,21 @@ ForestSamples <- R6::R6Class(
113113
return(output)
114114
},
115115

116+
#' @description
117+
#' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset`
118+
#' @param forest_dataset `ForestDataset` R class
119+
#' @param forest_num Index of the forest sample within the container
120+
#' @param tree_num Index of the tree to be queried
121+
#' @return matrix of predictions with as many rows as in `forest_dataset`
122+
#' and as many columns as dimensions in the leaves of trees in `ForestContainer`
123+
predict_raw_single_tree = function(forest_dataset, forest_num, tree_num) {
124+
stopifnot(!is.null(forest_dataset$data_ptr))
125+
126+
# Predict leaf values from forest
127+
output <- predict_forest_raw_single_tree_cpp(self$forest_container_ptr, forest_dataset$data_ptr, forest_num, tree_num)
128+
return(output)
129+
},
130+
116131
#' @description
117132
#' Set a constant predicted value for every tree in the ensemble.
118133
#' Stops program if any tree is more than a root node.

0 commit comments

Comments
 (0)