Skip to content

Commit 58a74b4

Browse files
authored
Merge pull request #342 from tidymodels/0-1-2-rc
0.1.2 release candidate
2 parents 0d80715 + 81957b3 commit 58a74b4

File tree

9 files changed

+76
-61
lines changed

9 files changed

+76
-61
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.1.1.9000
2+
Version: 0.1.2
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# parsnip (development version)
1+
# parsnip 0.1.2
22

33
## Breaking Changes
44

R/boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
476476
#' @param weights An optional numeric vector of case weights. Note
477477
#' that the data used for the case weights will not be used as a
478478
#' splitting variable in the model (see
479-
#' \url{http://www.rulequest.com/see5-win.html#CASEWEIGHT} for
479+
#' \url{http://www.rulequest.com/see5-win.html} for
480480
#' Quinlan's notes on case weights).
481481
#' @param minCases An integer for the smallest number of samples
482482
#' that must be put in at least two of the splits.

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ rand_forest(mtry = 10, trees = 2000) %>%
144144
#> Ranger result
145145
#>
146146
#> Call:
147-
#> ranger::ranger(formula = formula, data = data, mtry = ~10, num.trees = ~2000, importance = ~"impurity", num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1))
147+
#> ranger::ranger(formula = mpg ~ ., data = data, mtry = ~10, num.trees = ~2000, importance = ~"impurity", num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1))
148148
#>
149149
#> Type: Regression
150150
#> Number of trees: 2000
@@ -154,8 +154,8 @@ rand_forest(mtry = 10, trees = 2000) %>%
154154
#> Target node size: 5
155155
#> Variable importance mode: impurity
156156
#> Splitrule: variance
157-
#> OOB prediction error (MSE): 5.911312
158-
#> R squared (OOB): 0.837262
157+
#> OOB prediction error (MSE): 5.699772
158+
#> R squared (OOB): 0.8430857
159159
```
160160

161161
A list of all `parsnip` models across different CRAN packages can be

man/C5.0_train.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/contr_one_hot.Rd

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_linear_reg_glmnet.R

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,11 @@ test_that('glmnet prediction, single lambda', {
6868
y = hpc$input_fields
6969
)
7070

71-
uni_pred <- c(5.05125589060219, 4.86977761622526, 4.90912345599309, 4.93931874108359,
72-
5.08755154547758)
71+
# glmn_mod <- glmnet::glmnet(x = as.matrix(hpc[, num_pred]), y = hpc$input_fields,
72+
# alpha = .3, nlambda = 15)
73+
74+
uni_pred <- c(640.599944271351, 196.646976529848, 186.279646400216, 194.673852228774,
75+
198.126819755653)
7376

7477
expect_equal(uni_pred, predict(res_xy, hpc[1:5, num_pred])$.pred, tolerance = 0.0001)
7578

@@ -80,8 +83,8 @@ test_that('glmnet prediction, single lambda', {
8083
control = ctrl
8184
)
8285

83-
form_pred <- c(5.23960117346944, 5.08769210344022, 5.15129212608077, 5.12000510716518,
84-
5.26736239856889)
86+
form_pred <- c(570.504089227118, 162.413061474088, 167.022896537861, 157.609071878082,
87+
165.887783741483)
8588

8689
expect_equal(form_pred, predict(res_form, hpc[1:5,])$.pred, tolerance = 0.0001)
8790
})
@@ -118,16 +121,16 @@ test_that('glmnet prediction, multiple lambda', {
118121
mult_pred <-
119122
tibble::tribble(
120123
~penalty, ~.pred,
121-
0.01, 5.01352459498158,
122-
0.1, 5.05124049139868,
123-
0.01, 4.71767499960808,
124-
0.1, 4.87103404621362,
125-
0.01, 4.7791916685127,
126-
0.1, 4.91028250633598,
127-
0.01, 4.83366808792755,
128-
0.1, 4.9399094532023,
129-
0.01, 5.07269451405628,
130-
0.1, 5.08728178043569
124+
0.01, 639.672880668187,
125+
0.1, 639.672880668187,
126+
0.01, 197.744613311359,
127+
0.1, 197.744613311359,
128+
0.01, 187.737940787615,
129+
0.1, 187.737940787615,
130+
0.01, 195.780487678662,
131+
0.1, 195.780487678662,
132+
0.01, 199.217707535882,
133+
0.1, 199.217707535882
131134
)
132135

133136
expect_equal(
@@ -163,16 +166,16 @@ test_that('glmnet prediction, multiple lambda', {
163166
form_pred <-
164167
tibble::tribble(
165168
~penalty, ~.pred,
166-
0.01, 5.09237402805557,
167-
0.1, 5.24228948237804,
168-
0.01, 4.75071416991856,
169-
0.1, 5.09448280355765,
170-
0.01, 4.89375747015535,
171-
0.1, 5.15636527125752,
172-
0.01, 4.82338959520112,
173-
0.1, 5.12592317615935,
174-
0.01, 5.15481201301174,
175-
0.1, 5.26930099973607
169+
0.01, 570.474473760044,
170+
0.1, 570.474473760044,
171+
0.01, 164.040104978709,
172+
0.1, 164.040104978709,
173+
0.01, 168.709676954287,
174+
0.1, 168.709676954287,
175+
0.01, 159.173862504055,
176+
0.1, 159.173862504055,
177+
0.01, 167.559854709074,
178+
0.1, 167.559854709074
176179
)
177180

178181
expect_equal(
@@ -190,7 +193,7 @@ test_that('glmnet prediction, all lambda', {
190193
skip_if(run_glmnet)
191194

192195
hpc_all <- linear_reg(mixture = .3) %>%
193-
set_engine("glmnet")
196+
set_engine("glmnet", nlambda = 7)
194197

195198
res_xy <- fit_xy(
196199
hpc_all,
@@ -202,7 +205,7 @@ test_that('glmnet prediction, all lambda', {
202205
all_pred <- predict(res_xy$fit, newx = as.matrix(hpc[1:5, num_pred]))
203206
all_pred <- stack(as.data.frame(all_pred))
204207
all_pred$penalty <- rep(res_xy$fit$lambda, each = 5)
205-
all_pred$rows <- rep(1:5, 2)
208+
all_pred$rows <- rep(1:5, length(res_xy$fit$lambda))
206209
all_pred <- all_pred[order(all_pred$rows, all_pred$penalty), ]
207210
all_pred <- all_pred[, c("penalty", "values")]
208211
names(all_pred) <- c("penalty", ".pred")
@@ -223,7 +226,7 @@ test_that('glmnet prediction, all lambda', {
223226
form_pred <- predict(res_form$fit, newx = form_mat)
224227
form_pred <- stack(as.data.frame(form_pred))
225228
form_pred$penalty <- rep(res_form$fit$lambda, each = 5)
226-
form_pred$rows <- rep(1:5, 2)
229+
form_pred$rows <- rep(1:5, length(res_form$fit$lambda))
227230
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
228231
form_pred <- form_pred[, c("penalty", "values")]
229232
names(form_pred) <- c("penalty", ".pred")

tests/testthat/test_linear_reg_stan.R

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ library(parsnip)
33
library(rlang)
44

55
source(test_path("helper-objects.R"))
6-
hpc <- hpc_data[1:150, c(2:5, 8)]
6+
hpc <- hpc_data[, c(2:5, 8)]
77

88
# ------------------------------------------------------------------------------
99

@@ -62,10 +62,10 @@ test_that('stan prediction', {
6262
skip_if_not_installed("rstanarm")
6363
skip_on_cran()
6464

65-
uni_pred <- c(5.01531691055198, 4.6896592504705, 4.74907435900005, 4.82563873798984,
66-
5.08044844256827)
67-
inl_pred <- c(3.47062722437493, 3.38380776677489, 3.29336980560884, 3.24669710332179,
68-
3.42765162180813)
65+
uni_pred <- c(1691.46306020449, 1494.27323520418, 1522.36011539284, 1493.39683598195,
66+
1494.93053462084)
67+
inl_pred <- c(429.164145548939, 256.32488428038, 254.949927688403, 255.007333947447,
68+
255.336665165556)
6969

7070
res_xy <- fit_xy(
7171
linear_reg() %>%
@@ -99,27 +99,29 @@ test_that('stan intervals', {
9999
control = quiet_ctrl
100100
)
101101

102+
set.seed(1231)
102103
confidence_parsnip <-
103104
predict(res_xy,
104105
new_data = hpc[1:5,],
105106
type = "conf_int",
106107
level = 0.93)
107108

109+
set.seed(1231)
108110
prediction_parsnip <-
109111
predict(res_xy,
110112
new_data = hpc[1:5,],
111113
type = "pred_int",
112114
level = 0.93)
113115

114-
ci_lower <- c(4.93164991101342, 4.60197941230393, 4.6671442757811, 4.74402724639963,
115-
4.99248110476701)
116-
ci_upper <- c(5.1002837047058, 4.77617561853506, 4.83183673602725, 4.90844811805409,
117-
5.16979395659009)
116+
ci_lower <- c(1577.25718753727, 1382.58210286254, 1399.96490471468, 1381.56774986889,
117+
1383.25519963864)
118+
ci_upper <- c(1809.28331613624, 1609.11912475981, 1646.44852457781, 1608.3327281785,
119+
1609.4796390366)
118120

119-
pi_lower <- c(4.43202758985944, 4.09957733046886, 4.17664779714598, 4.24948546338885,
120-
4.50058914781073)
121-
pi_upper <- c(5.59783267637042, 5.25976504318669, 5.33296516452929, 5.41050668003565,
122-
5.66355828140989)
121+
pi_lower <- c(-4960.33135373564, -5123.82860109357, -5063.60881734505, -5341.21637448872,
122+
-5184.63627366821)
123+
pi_upper <- c(8345.56815544477, 7954.98392035813, 7890.10036321417, 7970.64062851536,
124+
8247.10241974192)
123125

124126
expect_equivalent(confidence_parsnip$.pred_lower, ci_lower, tolerance = 1e-2)
125127
expect_equivalent(confidence_parsnip$.pred_upper, ci_upper, tolerance = 1e-2)

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ library(testthat)
22
library(parsnip)
33
library(rlang)
44
library(tibble)
5+
library(dplyr)
56

67
# ------------------------------------------------------------------------------
78

89
context("multinom regression execution with glmnet")
910
source(test_path("helper-objects.R"))
10-
hpc <- hpc_data[1:150, c(2:5, 8)]
11+
hpc <- hpc_data[, c(2:5, 8)]
1112

1213
rows <- c(1, 51, 101)
1314

@@ -117,10 +118,14 @@ test_that('glmnet probabilities, mulitiple lambda', {
117118
names(mult_pred) <- NULL
118119
mult_pred <- tibble(.pred = mult_pred)
119120

120-
expect_equal(
121-
mult_pred$.pred,
122-
multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams, type = "prob")$.pred
123-
)
121+
multi_pred_res <- multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams, type = "prob")
122+
123+
for (i in seq_along(multi_pred_res$.pred)) {
124+
expect_equal(
125+
mult_pred %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
126+
multi_pred_res %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
127+
)
128+
}
124129

125130
mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)],
126131
levels = xy_fit$lvl)
@@ -134,10 +139,14 @@ test_that('glmnet probabilities, mulitiple lambda', {
134139
names(mult_class) <- NULL
135140
mult_class <- tibble(.pred = mult_class)
136141

137-
expect_equal(
138-
mult_class$.pred,
139-
multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams)$.pred
140-
)
142+
mult_class_res <- multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams)
143+
144+
for (i in seq_along(mult_class_res$.pred)) {
145+
expect_equal(
146+
mult_class %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
147+
mult_class_res %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
148+
)
149+
}
141150

142151
expect_error(
143152
multi_predict(xy_fit, newdata = hpc[rows, 1:4], penalty = lams),
@@ -157,7 +166,7 @@ test_that("class predictions are factors with all levels", {
157166
skip_if(run_glmnet)
158167

159168
basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(class ~ ., data = hpc)
160-
nd <- hpc[hpc$class == "setosa", ]
169+
nd <- hpc[hpc$class == "VF", ]
161170
yhat <- predict(basic, new_data = nd, penalty = .1)
162171
yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred
163172
expect_is(yhat_multi[[1]]$.pred_class, "factor")

0 commit comments

Comments
 (0)