Skip to content

Commit 5cab0f0

Browse files
authored
Merge pull request #100 from StochasticTree/forest_leaf_scale_enhancement
Added function to extract forest leaf scale parameters
2 parents c2f6538 + 0961921 commit 5cab0f0

File tree

5 files changed

+138
-19
lines changed

5 files changed

+138
-19
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export(bart)
88
export(bcf)
99
export(calibrate_inverse_gamma_error_variance)
1010
export(computeForestLeafIndices)
11+
export(computeForestLeafVariances)
1112
export(computeMaxLeafIndex)
1213
export(convertBARTModelToJson)
1314
export(convertBCFModelToJson)

R/kernel.R

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
#' Compute and return a vector representation of a forest's leaf predictions for
1+
#' Compute vector of forest leaf indices
2+
#'
3+
#' @description Compute and return a vector representation of a forest's leaf predictions for
24
#' every observation in a dataset.
3-
#' The vector has a "column-major" format that can be easily re-represented as
4-
#' as a CSC sparse matrix: elements are organized so that the first `n` elements
5+
#'
6+
#' The vector has a "row-major" format that can be easily re-represented as
7+
#' a CSR sparse matrix: elements are organized so that the first `n` elements
58
#' correspond to leaf predictions for all `n` observations in a dataset for the
69
#' first tree in an ensemble, the next `n` elements correspond to predictions for
710
#' the second tree and so on. The "data" for each element corresponds to a uniquely
@@ -12,7 +15,7 @@
1215
#' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample
1316
#' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.
1417
#' @param forest_type Which forest to use from `model_object`.
15-
#' Valid inputs depend on the model type, and whether or not a
18+
#' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model.
1619
#'
1720
#' **1. BART**
1821
#'
@@ -88,6 +91,89 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type, fore
8891
return(leaf_ind_matrix)
8992
}
9093

94+
#' Compute vector of forest leaf scale parameters
#'
#' @description Return the sampled leaf node scale parameter of a forest for
#' each requested forest sample.
#'
#' If leaf scale is not sampled for the forest in question, throws an error that the
#' leaf model does not have a stochastic scale parameter.
#'
#' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample
#' @param forest_type Which forest to use from `model_object`.
#' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model.
#'
#'   **1. BART**
#'
#'   - `'mean'`: Extracts leaf scale parameters for the mean forest
#'   - `'variance'`: Always an error, since no leaf scale parameter is sampled for the variance forest
#'
#'   **2. BCF**
#'
#'   - `'prognostic'`: Extracts leaf scale parameters for the prognostic forest
#'   - `'treatment'`: Extracts leaf scale parameters for the treatment effect forest
#'   - `'variance'`: Always an error, since no leaf scale parameter is sampled for the variance forest
#'
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to extract leaf scale parameters. If not provided,
#' this function will return the leaf scale parameter for every sample of a forest.
#' This function uses 1-indexing, so the first forest sample corresponds to `forest_inds = 1`, and so on.
#' @return Vector of size `length(forest_inds)` with the leaf scale parameter for each requested forest
#' sample (or one entry per retained sample when `forest_inds` is `NULL`).
#' @export
computeForestLeafVariances <- function(model_object, forest_type, forest_inds = NULL) {
    # Validate the model object; inherits() stays correct for objects that
    # carry more than one class, unlike comparing against class() directly
    stopifnot(inherits(model_object, c("bartmodel", "bcf")))
    stopifnot(length(forest_type) == 1)
    model_params <- model_object$model_params

    # Locate the vector of sampled leaf scale parameters for the requested forest
    if (inherits(model_object, "bartmodel")) {
        stopifnot(forest_type %in% c("mean", "variance"))
        if (forest_type == "mean") {
            if (!model_params$include_mean_forest) {
                stop("Mean forest was not sampled in the bart model provided")
            }
            if (!model_params$sample_sigma_leaf) {
                stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided")
            }
            leaf_scale_vector <- model_object$sigma2_leaf_samples
        } else {
            # Variance forests never have a sampled leaf scale parameter
            if (!model_params$include_variance_forest) {
                stop("Variance forest was not sampled in the bart model provided")
            }
            stop("Leaf scale parameter was not sampled for the variance forest in the bart model provided")
        }
    } else {
        stopifnot(forest_type %in% c("prognostic", "treatment", "variance"))
        if (forest_type == "prognostic") {
            if (!model_params$sample_sigma_leaf_mu) {
                stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided")
            }
            leaf_scale_vector <- model_object$sigma_leaf_mu_samples
        } else if (forest_type == "treatment") {
            if (!model_params$sample_sigma_leaf_tau) {
                stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided")
            }
            leaf_scale_vector <- model_object$sigma_leaf_tau_samples
        } else {
            # Variance forests never have a sampled leaf scale parameter
            if (!model_params$include_variance_forest) {
                stop("Variance forest was not sampled in the bcf model provided")
            }
            stop("Leaf scale parameter was not sampled for the variance forest in the bcf model provided")
        }
    }

    # Preprocess forest indices. The number of available samples is taken from
    # the scale parameter vector itself, since that is the object being indexed
    # (the previous code referenced an undefined `forest_container`).
    num_forests <- length(leaf_scale_vector)
    if (is.null(forest_inds)) {
        forest_inds <- seq_len(num_forests)
    } else {
        stopifnot(all(forest_inds <= num_forests))
        stopifnot(all(forest_inds >= 1))
        forest_inds <- as.integer(forest_inds)
    }

    # Gather leaf scale parameters
    leaf_scale_params <- leaf_scale_vector[forest_inds]

    return(leaf_scale_params)
}
176+
91177
#' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container.
92178
#'
93179
#' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample

_pkgdown.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ reference:
8686
- calibrate_inverse_gamma_error_variance
8787
- preprocessBartParams
8888
- preprocessBcfParams
89-
- computeForestLeafIndices
9089
- computeMaxLeafIndex
90+
- computeForestLeafIndices
91+
- computeForestLeafVariances
9192

9293
- subtitle: Random Effects
9394
desc: >
@@ -104,7 +105,6 @@ reference:
104105
- getRandomEffectSamples.bcf
105106
- sample_sigma2_one_iteration
106107
- sample_tau_one_iteration
107-
- sample_tau_one_iteration
108108

109109
- title: Package info
110110
desc: >

man/computeForestLeafIndices.Rd

Lines changed: 5 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/computeForestLeafVariances.Rd

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)