Skip to content

Commit 7a8a87c

Browse files
authored
Merge pull request #132 from StochasticTree/initial-cran-release-prep
Initial CRAN release preparation
2 parents f1e9772 + 12299e1 commit 7a8a87c

File tree

114 files changed

+3479
-1624
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

114 files changed

+3479
-1624
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
^.*\.Rproj$
22
^\.Rproj\.user$
3+
^cran-comments\.md$

.github/workflows/r-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939

4040
- name: Create a CRAN-ready version of the R package
4141
run: |
42-
Rscript cran-bootstrap.R 0
42+
Rscript cran-bootstrap.R 0 0
4343
4444
- uses: r-lib/actions/check-r-package@v2
4545
with:

DESCRIPTION

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
Package: stochtree
22
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
3-
Version: 0.0.1
3+
Version: 0.1.0
44
Authors@R:
55
c(
66
person("Drew", "Herren", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),
77
person("Richard", "Hahn", role = "aut"),
88
person("Jared", "Murray", role = "aut"),
99
person("Carlos", "Carvalho", role = "aut"),
10-
person("Jingyu", "He", role = "aut")
10+
person("Jingyu", "He", role = "aut"),
11+
person("stochtree contributors", role = c("cph"))
1112
)
1213
Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference.
1314
License: MIT + file LICENSE

NAMESPACE

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,37 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(getRandomEffectSamples,bartmodel)
4-
S3method(getRandomEffectSamples,bcf)
4+
S3method(getRandomEffectSamples,bcfmodel)
55
S3method(predict,bartmodel)
6-
S3method(predict,bcf)
6+
S3method(predict,bcfmodel)
77
export(bart)
88
export(bcf)
9-
export(calibrate_inverse_gamma_error_variance)
9+
export(calibrateInverseGammaErrorVariance)
1010
export(computeForestLeafIndices)
1111
export(computeForestLeafVariances)
12-
export(computeMaxLeafIndex)
13-
export(convertBARTModelToJson)
14-
export(convertBCFModelToJson)
12+
export(computeForestMaxLeafIndex)
1513
export(convertPreprocessorToJson)
1614
export(createBARTModelFromCombinedJson)
1715
export(createBARTModelFromCombinedJsonString)
1816
export(createBARTModelFromJson)
1917
export(createBARTModelFromJsonFile)
2018
export(createBARTModelFromJsonString)
19+
export(createBCFModelFromCombinedJson)
2120
export(createBCFModelFromCombinedJsonString)
2221
export(createBCFModelFromJson)
2322
export(createBCFModelFromJsonFile)
2423
export(createBCFModelFromJsonString)
2524
export(createCppJson)
2625
export(createCppJsonFile)
2726
export(createCppJsonString)
27+
export(createCppRNG)
2828
export(createForest)
29-
export(createForestContainer)
30-
export(createForestCovariates)
31-
export(createForestCovariatesFromMetadata)
3229
export(createForestDataset)
3330
export(createForestModel)
31+
export(createForestSamples)
3432
export(createOutcome)
3533
export(createPreprocessorFromJson)
3634
export(createPreprocessorFromJsonString)
37-
export(createRNG)
3835
export(createRandomEffectSamples)
3936
export(createRandomEffectsDataset)
4037
export(createRandomEffectsModel)
@@ -48,35 +45,28 @@ export(loadRandomEffectSamplesCombinedJsonString)
4845
export(loadRandomEffectSamplesJson)
4946
export(loadScalarJson)
5047
export(loadVectorJson)
51-
export(oneHotEncode)
52-
export(oneHotInitializeAndEncode)
53-
export(orderedCatInitializeAndPreprocess)
54-
export(orderedCatPreprocess)
55-
export(preprocessParams)
5648
export(preprocessPredictionData)
57-
export(preprocessPredictionDataFrame)
58-
export(preprocessPredictionMatrix)
5949
export(preprocessTrainData)
60-
export(preprocessTrainDataFrame)
61-
export(preprocessTrainMatrix)
6250
export(resetActiveForest)
6351
export(resetForestModel)
6452
export(resetRandomEffectsModel)
6553
export(resetRandomEffectsTracker)
66-
export(rootResetActiveForest)
6754
export(rootResetRandomEffectsModel)
6855
export(rootResetRandomEffectsTracker)
69-
export(sample_sigma2_one_iteration)
70-
export(sample_tau_one_iteration)
56+
export(sampleGlobalErrorVarianceOneIteration)
57+
export(sampleLeafVarianceOneIteration)
58+
export(saveBARTModelToJson)
7159
export(saveBARTModelToJsonFile)
7260
export(saveBARTModelToJsonString)
61+
export(saveBCFModelToJson)
7362
export(saveBCFModelToJsonFile)
7463
export(saveBCFModelToJsonString)
7564
export(savePreprocessorToJsonString)
7665
importFrom(R6,R6Class)
7766
importFrom(stats,coef)
7867
importFrom(stats,lm)
7968
importFrom(stats,model.matrix)
69+
importFrom(stats,predict)
8070
importFrom(stats,qgamma)
8171
importFrom(stats,resid)
8272
importFrom(stats,rnorm)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# stochtree 0.1.0
2+
3+
* Initial CRAN submission.

R/bart.R

Lines changed: 159 additions & 216 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 437 additions & 222 deletions
Large diffs are not rendered by default.

R/calibration.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#' X <- matrix(runif(n*p), ncol = p)
1919
#' y <- 10*X[,1] - 20*X[,2] + rnorm(n)
2020
#' nu <- 3
21-
#' lambda <- calibrate_inverse_gamma_error_variance(y, X, nu = nu)
21+
#' lambda <- calibrateInverseGammaErrorVariance(y, X, nu = nu)
2222
#' sigma2hat <- mean(resid(lm(y~X))^2)
2323
#' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat)
24-
calibrate_inverse_gamma_error_variance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) {
24+
calibrateInverseGammaErrorVariance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) {
2525
# Compute regression basis
2626
if (!is.null(W)) basis <- cbind(X, W)
2727
else basis <- X

R/cpp11.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ json_load_forest_container_cpp <- function(forest_samples, json_filename) {
292292
invisible(.Call(`_stochtree_json_load_forest_container_cpp`, forest_samples, json_filename))
293293
}
294294

295-
output_dimension_forest_container_cpp <- function(forest_samples) {
296-
.Call(`_stochtree_output_dimension_forest_container_cpp`, forest_samples)
295+
leaf_dimension_forest_container_cpp <- function(forest_samples) {
296+
.Call(`_stochtree_leaf_dimension_forest_container_cpp`, forest_samples)
297297
}
298298

299299
is_leaf_constant_forest_container_cpp <- function(forest_samples) {
@@ -464,8 +464,8 @@ predict_raw_active_forest_cpp <- function(active_forest, dataset) {
464464
.Call(`_stochtree_predict_raw_active_forest_cpp`, active_forest, dataset)
465465
}
466466

467-
output_dimension_active_forest_cpp <- function(active_forest) {
468-
.Call(`_stochtree_output_dimension_active_forest_cpp`, active_forest)
467+
leaf_dimension_active_forest_cpp <- function(active_forest) {
468+
.Call(`_stochtree_leaf_dimension_active_forest_cpp`, active_forest)
469469
}
470470

471471
average_max_depth_active_forest_cpp <- function(active_forest) {

R/data.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,14 @@ RandomEffectsDataset <- R6::R6Class(
228228
#'
229229
#' @return `ForestDataset` object
230230
#' @export
231+
#'
232+
#' @examples
233+
#' covariate_matrix <- matrix(runif(10*100), ncol = 10)
234+
#' basis_matrix <- matrix(rnorm(3*100), ncol = 3)
235+
#' weight_vector <- rnorm(100)
236+
#' forest_dataset <- createForestDataset(covariate_matrix)
237+
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix)
238+
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector)
231239
createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){
232240
return(invisible((
233241
ForestDataset$new(covariates, basis, variance_weights)
@@ -240,6 +248,11 @@ createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){
240248
#'
241249
#' @return `Outcome` object
242250
#' @export
251+
#'
252+
#' @examples
253+
#' X <- matrix(runif(10*100), ncol = 10)
254+
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
255+
#' outcome <- createOutcome(y)
243256
createOutcome <- function(outcome){
244257
return(invisible((
245258
Outcome$new(outcome)
@@ -254,6 +267,13 @@ createOutcome <- function(outcome){
254267
#'
255268
#' @return `RandomEffectsDataset` object
256269
#' @export
270+
#'
271+
#' @examples
272+
#' rfx_group_ids <- sample(1:2, size = 100, replace = TRUE)
273+
#' rfx_basis <- matrix(rnorm(3*100), ncol = 3)
274+
#' weight_vector <- rnorm(100)
275+
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
276+
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector)
257277
createRandomEffectsDataset <- function(group_labels, basis, variance_weights=NULL){
258278
return(invisible((
259279
RandomEffectsDataset$new(group_labels, basis, variance_weights)

0 commit comments

Comments
 (0)