Skip to content

Commit a871609

Browse files
committed
model-specific changes for #315
1 parent 3e46c0d commit a871609

File tree

9 files changed

+93
-50
lines changed

9 files changed

+93
-50
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ export(predict_quantile.model_fit)
142142
export(predict_raw)
143143
export(predict_raw.model_fit)
144144
export(rand_forest)
145+
export(repair_call)
145146
export(rpart_train)
146147
export(set_args)
147148
export(set_dependency)

R/mars_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ set_fit(
5252
eng = "earth",
5353
mode = "classification",
5454
value = list(
55-
interface = "data.frame",
56-
protect = c("x", "y", "weights"),
55+
interface = "formula",
56+
protect = c("formula", "data", "weights"),
5757
func = c(pkg = "earth", fun = "earth"),
5858
defaults = list(keepxy = TRUE)
5959
)

R/svm_poly_data.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ set_fit(
4949
eng = "kernlab",
5050
mode = "regression",
5151
value = list(
52-
interface = "matrix",
53-
protect = c("x", "y"),
52+
interface = "formula",
53+
data = c(formula = "x", data = "data"),
54+
protect = c("x", "data"),
5455
func = c(pkg = "kernlab", fun = "ksvm"),
5556
defaults = list(kernel = "polydot")
5657
)
@@ -61,8 +62,9 @@ set_fit(
6162
eng = "kernlab",
6263
mode = "classification",
6364
value = list(
64-
interface = "matrix",
65-
protect = c("x", "y"),
65+
interface = "formula",
66+
data = c(formula = "x", data = "data"),
67+
protect = c("x", "data"),
6668
func = c(pkg = "kernlab", fun = "ksvm"),
6769
defaults = list(kernel = "polydot")
6870
)

man/mars.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/repair_call.Rd

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/svm_poly.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mars.R

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ library(rlang)
55
# ------------------------------------------------------------------------------
66

77
context("mars tests")
8-
source("helpers.R")
9-
source("helper-objects.R")
8+
source(test_path("helpers.R"))
9+
source(test_path("helper-objects.R"))
1010

1111
# ------------------------------------------------------------------------------
1212

@@ -15,8 +15,8 @@ test_that('primary arguments', {
1515
basic_mars <- translate(basic %>% set_engine("earth"))
1616
expect_equal(basic_mars$method$fit$args,
1717
list(
18-
x = expr(missing_arg()),
19-
y = expr(missing_arg()),
18+
formula = expr(missing_arg()),
19+
data = expr(missing_arg()),
2020
weights = expr(missing_arg()),
2121
keepxy = TRUE
2222
)
@@ -26,8 +26,8 @@ test_that('primary arguments', {
2626
num_terms_mars <- translate(num_terms %>% set_engine("earth"))
2727
expect_equal(num_terms_mars$method$fit$args,
2828
list(
29-
x = expr(missing_arg()),
30-
y = expr(missing_arg()),
29+
formula = expr(missing_arg()),
30+
data = expr(missing_arg()),
3131
weights = expr(missing_arg()),
3232
nprune = new_empty_quosure(4),
3333
glm = rlang::quo(list(family = stats::binomial)),
@@ -39,8 +39,8 @@ test_that('primary arguments', {
3939
prod_degree_mars <- translate(prod_degree %>% set_engine("earth"))
4040
expect_equal(prod_degree_mars$method$fit$args,
4141
list(
42-
x = expr(missing_arg()),
43-
y = expr(missing_arg()),
42+
formula = expr(missing_arg()),
43+
data = expr(missing_arg()),
4444
weights = expr(missing_arg()),
4545
degree = new_empty_quosure(1),
4646
keepxy = TRUE
@@ -51,8 +51,8 @@ test_that('primary arguments', {
5151
prune_method_v_mars <- translate(prune_method_v %>% set_engine("earth"))
5252
expect_equal(prune_method_v_mars$method$fit$args,
5353
list(
54-
x = expr(missing_arg()),
55-
y = expr(missing_arg()),
54+
formula = expr(missing_arg()),
55+
data = expr(missing_arg()),
5656
weights = expr(missing_arg()),
5757
pmethod = new_empty_quosure(varying()),
5858
keepxy = TRUE
@@ -64,8 +64,8 @@ test_that('engine arguments', {
6464
mars_keep <- mars(mode = "regression")
6565
expect_equal(translate(mars_keep %>% set_engine("earth", keepxy = FALSE))$method$fit$args,
6666
list(
67-
x = expr(missing_arg()),
68-
y = expr(missing_arg()),
67+
formula = expr(missing_arg()),
68+
data = expr(missing_arg()),
6969
weights = expr(missing_arg()),
7070
keepxy = new_empty_quosure(FALSE)
7171
)
@@ -107,16 +107,9 @@ test_that('updating', {
107107
})
108108

109109
test_that('bad input', {
110-
# expect_error(mars(prod_degree = -1))
111-
# expect_error(mars(num_terms = -1))
112110
expect_error(translate(mars() %>% set_engine("wat?")))
113111
expect_error(translate(mars(mode = "regression") %>% set_engine()))
114112
expect_error(translate(mars(formula = y ~ x)))
115-
expect_warning(
116-
translate(
117-
mars(mode = "regression") %>% set_engine("earth", x = iris[,1:3], y = iris$Species)
118-
)
119-
)
120113
})
121114

122115
# ------------------------------------------------------------------------------
@@ -154,13 +147,16 @@ test_that('mars execution', {
154147
expect_true(has_multi_predict(res))
155148
expect_equal(multi_predict_args(res), "num_terms")
156149

157-
expect_error(
158-
res <- fit(
159-
iris_basic,
160-
iris_bad_form,
161-
data = iris,
162-
control = ctrl
163-
)
150+
expect_message(
151+
expect_error(
152+
res <- fit(
153+
iris_basic,
154+
iris_bad_form,
155+
data = iris,
156+
control = ctrl
157+
)
158+
),
159+
"Timing stopped"
164160
)
165161

166162
## multivariate y
@@ -259,22 +255,23 @@ test_that('submodel prediction', {
259255
mp_res <- do.call("rbind", mp_res$.pred)
260256
expect_equal(mp_res[[".pred"]], pruned_reg_pred)
261257

258+
full_churn <- wa_churn[complete.cases(wa_churn), ]
262259
vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges")
263260
class_fit <-
264261
mars(mode = "classification", prune_method = "none") %>%
265262
set_engine("earth", keepxy = TRUE) %>%
266263
fit(churn ~ .,
267-
data = wa_churn[-(1:4), c("churn", vars)])
264+
data = full_churn[-(1:4), c("churn", vars)])
268265

269266
cls_fit <- class_fit$fit
270267
cls_fit$call[["pmethod"]] <- eval_tidy(cls_fit$call[["pmethod"]])
271268
cls_fit$call[["keepxy"]] <- eval_tidy(cls_fit$call[["keepxy"]])
272269
cls_fit$call[["glm"]] <- eval_tidy(cls_fit$call[["glm"]])
273270

274271
pruned_cls <- update(cls_fit, nprune = 5)
275-
pruned_cls_pred <- predict(pruned_cls, wa_churn[1:4, vars], type = "response")[,1]
272+
pruned_cls_pred <- predict(pruned_cls, full_churn[1:4, vars], type = "response")[,1]
276273

277-
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], num_terms = 5, type = "prob")
274+
mp_res <- multi_predict(class_fit, new_data = full_churn[1:4, vars], num_terms = 5, type = "prob")
278275
mp_res <- do.call("rbind", mp_res$.pred)
279276
expect_equal(mp_res[[".pred_No"]], pruned_cls_pred)
280277

tests/testthat/test_svm_poly.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ test_that('primary arguments', {
1818
object = basic_kernlab$method$fit$args,
1919
expected = list(
2020
x = expr(missing_arg()),
21-
y = expr(missing_arg()),
21+
data = expr(missing_arg()),
2222
kernel = "polydot"
2323
)
2424
)
@@ -32,7 +32,7 @@ test_that('primary arguments', {
3232
object = degree_kernlab$method$fit$args,
3333
expected = list(
3434
x = expr(missing_arg()),
35-
y = expr(missing_arg()),
35+
data = expr(missing_arg()),
3636
kernel = "polydot",
3737
kpar = degree_obj
3838
)
@@ -48,7 +48,7 @@ test_that('primary arguments', {
4848
object = degree_scale_kernlab$method$fit$args,
4949
expected = list(
5050
x = expr(missing_arg()),
51-
y = expr(missing_arg()),
51+
data = expr(missing_arg()),
5252
kernel = "polydot",
5353
kpar = degree_scale_obj
5454
)
@@ -64,7 +64,7 @@ test_that('engine arguments', {
6464
object = translate(kernlab_cv, "kernlab")$method$fit$args,
6565
expected = list(
6666
x = expr(missing_arg()),
67-
y = expr(missing_arg()),
67+
data = expr(missing_arg()),
6868
cross = new_empty_quosure(10),
6969
kernel = "polydot"
7070
)
@@ -189,7 +189,7 @@ test_that('svm poly regression prediction', {
189189
y = iris$Sepal.Length,
190190
control = ctrl
191191
)
192-
expect_equal(reg_form$fit, reg_xy_form$fit)
192+
expect_equal(reg_form$fit@alphaindex, reg_xy_form$fit@alphaindex)
193193

194194
parsnip_xy_pred <- predict(reg_xy_form, iris[1:3, -c(1, 5)])
195195
expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_xy_pred))
@@ -260,7 +260,7 @@ test_that('svm poly classification probabilities', {
260260
y = iris$Species,
261261
control = ctrl
262262
)
263-
expect_equal(cls_form$fit, cls_xy_form$fit)
263+
expect_equal(cls_form$fit@alphaindex, cls_xy_form$fit@alphaindex)
264264

265265
library(kernlab)
266266
kern_probs <-

tests/testthat/test_svm_rbf.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ library(rlang)
55
# ------------------------------------------------------------------------------
66

77
context("poly SVM")
8-
source("helpers.R")
8+
source(test_path("helpers.R"))
99

1010
# ------------------------------------------------------------------------------
1111

@@ -17,7 +17,7 @@ test_that('primary arguments', {
1717
object = basic_kernlab$method$fit$args,
1818
expected = list(
1919
x = expr(missing_arg()),
20-
y = expr(missing_arg()),
20+
data = expr(missing_arg()),
2121
kernel = "rbfdot"
2222
)
2323
)
@@ -31,7 +31,7 @@ test_that('primary arguments', {
3131
object = rbf_sigma_kernlab$method$fit$args,
3232
expected = list(
3333
x = expr(missing_arg()),
34-
y = expr(missing_arg()),
34+
data = expr(missing_arg()),
3535
kernel = "rbfdot",
3636
kpar = rbf_sigma_obj
3737
)
@@ -47,7 +47,7 @@ test_that('engine arguments', {
4747
object = translate(kernlab_cv, "kernlab")$method$fit$args,
4848
expected = list(
4949
x = expr(missing_arg()),
50-
y = expr(missing_arg()),
50+
data = expr(missing_arg()),
5151
cross = new_empty_quosure(10),
5252
kernel = "rbfdot"
5353
)
@@ -164,7 +164,7 @@ test_that('svm rbf regression prediction', {
164164
y = iris$Sepal.Length,
165165
control = ctrl
166166
)
167-
expect_equal(reg_form$fit, reg_xy_form$fit)
167+
expect_equal(reg_form$fit@alphaindex, reg_xy_form$fit@alphaindex)
168168

169169
parsnip_xy_pred <- predict(reg_xy_form, iris[1:3, -c(1, 5)])
170170
expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_xy_pred))
@@ -235,7 +235,7 @@ test_that('svm rbf classification probabilities', {
235235
y = iris$Species,
236236
control = ctrl
237237
)
238-
expect_equal(cls_form$fit, cls_xy_form$fit)
238+
expect_equal(cls_form$fit@alphaindex, cls_xy_form$fit@alphaindex)
239239

240240
library(kernlab)
241241
kern_probs <-

0 commit comments

Comments
 (0)