Skip to content

Commit c8765c7

Browse files
topepo and hfrick authored
add partykit wrappers (#697)
* add partykit wrappers * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> * changes based on reviewer feedback Co-authored-by: Hannah Frick <[email protected]>
1 parent 152caec commit c8765c7

File tree

6 files changed

+207
-1
lines changed

6 files changed

+207
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 0.2.1.9000
3+
Version: 0.2.1.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,15 @@ export(bag_tree)
171171
export(bart)
172172
export(bartMachine_interval_calc)
173173
export(boost_tree)
174+
export(cforest_train)
174175
export(check_empty_ellipse)
175176
export(check_final_param)
176177
export(check_model_doesnt_exist)
177178
export(check_model_exists)
178179
export(contr_one_hot)
179180
export(control_parsnip)
180181
export(convert_stan_interval)
182+
export(ctree_train)
181183
export(cubist_rules)
182184
export(dbart_predict_calc)
183185
export(decision_tree)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* Added `ctree_train()` and `cforest_train()` wrappers for the functions in the partykit package. Engines for these will be added to other parsnip extension packages.
4+
35
* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688).
46

57

R/partykit.R

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#' A wrapper function for conditional inference tree models
2+
#'
3+
#' These functions are slightly different APIs for [partykit::ctree()] and
4+
#' [partykit::cforest()] that have several important arguments as top-level
5+
#' arguments (as opposed to being specified in [partykit::ctree_control()]).
6+
#' @param formula A symbolic description of the model to be fit.
7+
#' @param data A data frame containing the variables in the model.
8+
#' @param teststat A character specifying the type of the test statistic to be
9+
#' applied.
10+
#' @param testtype A character specifying how to compute the distribution of
11+
#' the test statistic.
12+
#' @param mincriterion The value of the test statistic (for \code{testtype ==
13+
#' "Teststatistic"}), or 1 - p-value (for other values of \code{testtype}) that
14+
#' must be exceeded in order to implement a split.
15+
#' @param minsplit The minimum sum of weights in a node in order to be
16+
#' considered for splitting.
17+
#' @param maxdepth maximum depth of the tree. The default \code{maxdepth = Inf}
18+
#' means that no restrictions are applied to tree sizes.
19+
#' @param mtry Number of input variables randomly sampled as candidates at each
20+
#' node for random forest like algorithms. For `cforest_train()` the default is
21+
#' `ceiling(sqrt(ncol(data) - 1))`; \code{mtry = Inf} means that no random
#' selection takes place.
22+
#' @param ntree Number of trees to grow in a forest.
23+
#' @param ... Other options to pass to [partykit::ctree()] or [partykit::cforest()].
24+
#' @return An object of class `party` (for `ctree`) or `cforest`.
25+
#' @examples
26+
#' if (rlang::is_installed(c("modeldata", "partykit"))) {
27+
#' data(bivariate, package = "modeldata")
28+
#' ctree_train(Class ~ ., data = bivariate_train)
29+
#' ctree_train(Class ~ ., data = bivariate_train, maxdepth = 1)
30+
#' }
31+
#' @export
32+
ctree_train <-
  function(formula,
           data,
           minsplit = 20L,
           maxdepth = Inf,
           teststat = "quadratic",
           testtype = "Bonferroni",
           mincriterion = 0.95,
           ...) {
    rlang::check_installed("partykit")
    # Everything passed through `...` is forwarded to partykit::ctree().
    opts <- rlang::list2(...)

    if (any(names(opts) == "control")) {
      # The caller supplied a control object: the top-level arguments take
      # precedence and overwrite the matching entries of that object.
      opts$control$minsplit <- minsplit
      opts$control$maxdepth <- maxdepth
      opts$control$teststat <- teststat
      opts$control$testtype <- testtype
      # Keep the log-scale threshold in sync with `mincriterion`, exactly as
      # cforest_train() does. Without this, a control object created by
      # partykit::ctree_control() would keep the threshold it was built with
      # and the `mincriterion` argument given here would have no effect.
      opts$control$logmincriterion <- log(mincriterion)
      opts$control$mincriterion <- mincriterion
    } else {
      # No control object given: build an unevaluated call to
      # partykit::ctree_control(); it is evaluated as part of the model call.
      opts$control <-
        rlang::call2(
          "ctree_control",
          .ns = "partykit",
          !!!list(
            minsplit = minsplit,
            maxdepth = maxdepth,
            teststat = teststat,
            testtype = testtype,
            mincriterion = mincriterion
          )
        )
    }

    # `formula` and `data` are inserted as symbols so that the call refers to
    # this function's arguments instead of embedding their values.
    tree_call <-
      rlang::call2(
        "ctree",
        .ns = "partykit",
        formula = rlang::expr(formula),
        data = rlang::expr(data),
        !!!opts
      )
    rlang::eval_tidy(tree_call)
  }
75+
76+
#' @rdname ctree_train
77+
#' @export
78+
cforest_train <-
79+
function(formula,
80+
data,
81+
minsplit = 20L,
82+
maxdepth = Inf,
83+
teststat = "quadratic",
84+
testtype = "Univariate",
85+
mincriterion = 0,
86+
mtry = ceiling(sqrt(ncol(data) - 1)),
87+
ntree = 500L,
88+
...) {
89+
rlang::check_installed("partykit")
90+
force(mtry)
91+
opts <- rlang::list2(...)
92+
93+
if (any(names(opts) == "control")) {
94+
opts$control$minsplit <- minsplit
95+
opts$control$maxdepth <- maxdepth
96+
opts$control$teststat <- teststat
97+
opts$control$testtype <- testtype
98+
opts$control$logmincriterion <- log(mincriterion)
99+
opts$control$mincriterion <- mincriterion
100+
} else {
101+
opts$control <-
102+
rlang::call2(
103+
"ctree_control",
104+
.ns = "partykit",
105+
!!!list(
106+
minsplit = minsplit,
107+
maxdepth = maxdepth,
108+
teststat = teststat,
109+
testtype = testtype,
110+
mincriterion = mincriterion,
111+
saveinfo = FALSE
112+
)
113+
)
114+
}
115+
opts$mtry <- mtry
116+
opts$ntree <- ntree
117+
forest_call <-
118+
rlang::call2(
119+
"cforest",
120+
.ns = "partykit",
121+
formula = rlang::expr(formula),
122+
data = rlang::expr(data),
123+
!!!opts
124+
)
125+
rlang::eval_tidy(forest_call)
126+
}

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ reference:
8787
- tidy.model_fit
8888
- translate
8989
- starts_with("update")
90+
- matches("_train")
9091

9192
- title: Developer tools
9293
contents:

man/ctree_train.Rd

Lines changed: 75 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)