|
1 |
| -#' Compute and return a vector representation of a forest's leaf predictions for |
| 1 | +#' Compute vector of forest leaf indices |
| 2 | +#' |
| 3 | +#' @description Compute and return a vector representation of a forest's leaf predictions for |
2 | 4 | #' every observation in a dataset.
|
3 |
| -#' The vector has a "column-major" format that can be easily re-represented as |
4 |
| -#' as a CSC sparse matrix: elements are organized so that the first `n` elements |
| 5 | +#' |
| 6 | +#' The vector has a "row-major" format that can be easily re-represented as |
| 7 | +#' as a CSR sparse matrix: elements are organized so that the first `n` elements |
5 | 8 | #' correspond to leaf predictions for all `n` observations in a dataset for the
|
6 | 9 | #' first tree in an ensemble, the next `n` elements correspond to predictions for
|
7 | 10 | #' the second tree and so on. The "data" for each element corresponds to a uniquely
|
|
12 | 15 | #' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample
|
13 | 16 | #' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.
|
14 | 17 | #' @param forest_type Which forest to use from `model_object`.
|
15 |
| -#' Valid inputs depend on the model type, and whether or not a |
| 18 | +#' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model. |
16 | 19 | #'
|
17 | 20 | #' **1. BART**
|
18 | 21 | #'
|
@@ -88,6 +91,89 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type, fore
|
88 | 91 | return(leaf_ind_matrix)
|
89 | 92 | }
|
90 | 93 |
|
| 94 | +#' Compute vector of forest leaf scale parameters |
| 95 | +#' |
| 96 | +#' @description Return each forest's leaf node scale parameters. |
| 97 | +#' |
| 98 | +#' If leaf scale is not sampled for the forest in question, throws an error that the |
| 99 | +#' leaf model does not have a stochastic scale parameter. |
| 100 | +#' |
| 101 | +#' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample |
| 102 | +#' @param forest_type Which forest to use from `model_object`. |
| 103 | +#' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model. |
| 104 | +#' |
| 105 | +#' **1. BART** |
| 106 | +#' |
| 107 | +#' - `'mean'`: Extracts leaf indices for the mean forest |
| 108 | +#' - `'variance'`: Extracts leaf indices for the variance forest |
| 109 | +#' |
| 110 | +#' **2. BCF** |
| 111 | +#' |
| 112 | +#' - `'prognostic'`: Extracts leaf indices for the prognostic forest |
| 113 | +#' - `'treatment'`: Extracts leaf indices for the treatment effect forest |
| 114 | +#' - `'variance'`: Extracts leaf indices for the variance forest |
| 115 | +#' |
| 116 | +#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, |
| 117 | +#' this function will return leaf indices for every sample of a forest. |
| 118 | +#' This function uses 1-indexing, so the first forest sample corresponds to `forest_num = 1`, and so on. |
| 119 | +#' @return Vector of size `length(forest_inds)` with the leaf scale parameter for each requested forest. |
| 120 | +#' @export |
| 121 | +computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) { |
| 122 | + # Extract relevant forest container |
| 123 | + stopifnot(class(model_object) %in% c("bartmodel", "bcf")) |
| 124 | + model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf") |
| 125 | + if (model_type == "bart") { |
| 126 | + stopifnot(forest_type %in% c("mean", "variance")) |
| 127 | + if (forest_type=="mean") { |
| 128 | + if (!model_object$model_params$include_mean_forest) { |
| 129 | + stop("Mean forest was not sampled in the bart model provided") |
| 130 | + } |
| 131 | + if (model_object$model_params$sample_sigma_leaf == F) { |
| 132 | + stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided") |
| 133 | + } |
| 134 | + leaf_scale_vector <- model_object$sigma2_leaf_samples |
| 135 | + } else if (forest_type=="variance") { |
| 136 | + if (!model_object$model_params$include_variance_forest) { |
| 137 | + stop("Variance forest was not sampled in the bart model provided") |
| 138 | + } |
| 139 | + stop("Leaf scale parameter was not sampled for the variance forest in the bart model provided") |
| 140 | + } |
| 141 | + } else { |
| 142 | + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) |
| 143 | + if (forest_type=="prognostic") { |
| 144 | + if (model_object$model_params$sample_sigma_leaf_mu == F) { |
| 145 | + stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided") |
| 146 | + } |
| 147 | + leaf_scale_vector <- model_object$sigma_leaf_mu_samples |
| 148 | + } else if (forest_type=="treatment") { |
| 149 | + if (model_object$model_params$sample_sigma_leaf_tau == F) { |
| 150 | + stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided") |
| 151 | + } |
| 152 | + leaf_scale_vector <- model_object$sigma_leaf_tau_samples |
| 153 | + } else if (forest_type=="variance") { |
| 154 | + if (!model_object$model_params$include_variance_forest) { |
| 155 | + stop("Variance forest was not sampled in the bcf model provided") |
| 156 | + } |
| 157 | + stop("Leaf scale parameter was not sampled for the variance forest in the bcf model provided") |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + # Preprocess forest indices |
| 162 | + num_forests <- forest_container$num_samples() |
| 163 | + if (is.null(forest_inds)) { |
| 164 | + forest_inds <- as.integer(1:num_forests) |
| 165 | + } else { |
| 166 | + stopifnot(all(forest_inds <= num_forests)) |
| 167 | + stopifnot(all(forest_inds >= 1)) |
| 168 | + forest_inds <- as.integer(forest_inds) |
| 169 | + } |
| 170 | + |
| 171 | + # Gather leaf scale parameters |
| 172 | + leaf_scale_params <- leaf_scale_vector[forest_inds] |
| 173 | + |
| 174 | + return(leaf_scale_params) |
| 175 | +} |
| 176 | + |
91 | 177 | #' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container.
|
92 | 178 | #'
|
93 | 179 | #' @param model_object Object of type `bartmodel` or `bcf` corresponding to a BART / BCF model with at least one forest sample
|
|
0 commit comments