Skip to content

Commit fbd544a

Browse files
committed
_almost_ all tests pass
1 parent 409fdf1 commit fbd544a

26 files changed

+192
-184
lines changed

R/aaa.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
1818
res
1919
}
2020

21+
# ------------------------------------------------------------------------------
22+
23+
#' @importFrom utils globalVariables
24+
utils::globalVariables(c("value", "engine", "lab", "original", "engine2"))

R/boost_tree.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,15 +387,15 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
387387
pred <- xgb_pred(object$fit, newdata = new_data, ntreelimit = tree)
388388

389389
# switch based on prediction type
390-
if(object$spec$mode == "regression") {
390+
if (object$spec$mode == "regression") {
391391
pred <- tibble(.pred = pred)
392392
nms <- names(pred)
393393
} else {
394394
if (type == "class") {
395-
pred <- boost_tree_xgboost_data$class$post(pred, object)
395+
pred <- object$spec$method$pred$class$post(pred, object)
396396
pred <- tibble(.pred = factor(pred, levels = object$lvl))
397397
} else {
398-
pred <- boost_tree_xgboost_data$classprob$post(pred, object)
398+
pred <- object$spec$method$pred$prob$post(pred, object)
399399
pred <- as_tibble(pred)
400400
names(pred) <- paste0(".pred_", names(pred))
401401
}

R/decision_tree_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ set_model_arg(
141141
mod = "decision_tree",
142142
eng = "C5.0",
143143
val = "min_n",
144-
original = "minsplit",
144+
original = "minCases",
145145
func = list(pkg = "dials", fun = "min_n"),
146146
submodels = FALSE
147147
)

R/linear_reg_data.R

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,16 @@ set_pred(
209209
.pred_lower =
210210
convert_stan_interval(
211211
results,
212-
level = object$spec$method$confint$extras$level
212+
level = object$spec$method$pred$conf_int$extras$level
213213
),
214214
.pred_upper =
215215
convert_stan_interval(
216216
results,
217-
level = object$spec$method$confint$extras$level,
217+
level = object$spec$method$pred$conf_int$extras$level,
218218
lower = FALSE
219219
),
220220
)
221-
if(object$spec$method$confint$extras$std_error)
221+
if (object$spec$method$pred$conf_int$extras$std_error)
222222
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
223223
res
224224
},
@@ -246,16 +246,16 @@ set_pred(
246246
.pred_lower =
247247
convert_stan_interval(
248248
results,
249-
level = object$spec$method$predint$extras$level
249+
level = object$spec$method$pred$pred_int$extras$level
250250
),
251251
.pred_upper =
252252
convert_stan_interval(
253253
results,
254-
level = object$spec$method$predint$extras$level,
254+
level = object$spec$method$pred$pred_int$extras$level,
255255
lower = FALSE
256256
),
257257
)
258-
if(object$spec$method$predint$extras$std_error)
258+
if (object$spec$method$pred$pred_int$extras$std_error)
259259
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
260260
res
261261
},
@@ -287,6 +287,25 @@ set_pred(
287287
set_model_engine("linear_reg", "regression", "spark")
288288
set_dependency("linear_reg", "spark", "sparklyr")
289289

290+
set_model_arg(
291+
mod = "linear_reg",
292+
eng = "spark",
293+
val = "penalty",
294+
original = "reg_param",
295+
func = list(pkg = "dials", fun = "penalty"),
296+
submodels = TRUE
297+
)
298+
299+
set_model_arg(
300+
mod = "linear_reg",
301+
eng = "spark",
302+
val = "mixture",
303+
original = "elastic_net_param",
304+
func = list(pkg = "dials", fun = "mixture"),
305+
submodels = FALSE
306+
)
307+
308+
290309
set_fit(
291310
mod = "linear_reg",
292311
eng = "spark",

R/logistic_reg_data.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ set_pred(
8484
value = list(
8585
pre = NULL,
8686
post = function(results, object) {
87-
hf_lvl <- (1 - object$spec$method$confint$extras$level)/2
87+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
8888
const <-
8989
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
9090
trans <- object$fit$family$linkinv
@@ -101,7 +101,7 @@ set_pred(
101101
hi_nms <- paste0(".pred_upper_", object$lvl)
102102
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
103103

104-
if (object$spec$method$confint$extras$std_error)
104+
if (object$spec$method$pred$conf_int$extras$std_error)
105105
res$.std_error <- results$se.fit
106106
res
107107
},
@@ -214,7 +214,7 @@ set_dependency("logistic_reg", "spark", "sparklyr")
214214

215215
set_model_arg(
216216
mod = "logistic_reg",
217-
eng = "glmnet",
217+
eng = "spark",
218218
val = "penalty",
219219
original = "reg_param",
220220
func = list(pkg = "dials", fun = "penalty"),
@@ -223,9 +223,9 @@ set_model_arg(
223223

224224
set_model_arg(
225225
mod = "logistic_reg",
226-
eng = "glmnet",
227-
val = "elastic_net_param",
228-
original = "alpha",
226+
eng = "spark",
227+
val = "mixture",
228+
original = "elastic_net_param",
229229
func = list(pkg = "dials", fun = "mixture"),
230230
submodels = FALSE
231231
)
@@ -439,12 +439,12 @@ set_pred(
439439
lo =
440440
convert_stan_interval(
441441
results,
442-
level = object$spec$method$confint$extras$level
442+
level = object$spec$method$pred$conf_int$extras$level
443443
),
444444
hi =
445445
convert_stan_interval(
446446
results,
447-
level = object$spec$method$confint$extras$level,
447+
level = object$spec$method$pred$conf_int$extras$level,
448448
lower = FALSE
449449
),
450450
)
@@ -456,7 +456,7 @@ set_pred(
456456
hi_nms <- paste0(".pred_upper_", object$lvl)
457457
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
458458

459-
if (object$spec$method$confint$extras$std_error)
459+
if (object$spec$method$pred$conf_int$extras$std_error)
460460
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
461461
res
462462
},
@@ -484,12 +484,12 @@ set_pred(
484484
lo =
485485
convert_stan_interval(
486486
results,
487-
level = object$spec$method$predint$extras$level
487+
level = object$spec$method$pred$pred_int$extras$level
488488
),
489489
hi =
490490
convert_stan_interval(
491491
results,
492-
level = object$spec$method$predint$extras$level,
492+
level = object$spec$method$pred$pred_int$extras$level,
493493
lower = FALSE
494494
),
495495
)
@@ -501,7 +501,7 @@ set_pred(
501501
hi_nms <- paste0(".pred_upper_", object$lvl)
502502
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
503503

504-
if (object$spec$method$predint$extras$std_error)
504+
if (object$spec$method$pred$pred_int$extras$std_error)
505505
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
506506
res
507507
},

R/mlp_data.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,18 @@ set_model_arg(
187187
mod = "mlp",
188188
eng = "nnet",
189189
val = "penalty",
190-
original = "penalty",
190+
original = "decay",
191191
func = list(pkg = "dials", fun = "weight_decay"),
192192
submodels = FALSE
193193
)
194-
194+
set_model_arg(
195+
mod = "mlp",
196+
eng = "nnet",
197+
val = "epochs",
198+
original = "maxit",
199+
func = list(pkg = "dials", fun = "epochs"),
200+
submodels = FALSE
201+
)
195202
set_fit(
196203
mod = "mlp",
197204
eng = "nnet",

R/multinom_reg_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ set_dependency("multinom_reg", "spark", "sparklyr")
100100

101101
set_model_arg(
102102
mod = "multinom_reg",
103-
eng = "glmnet",
103+
eng = "spark",
104104
val = "penalty",
105105
original = "reg_param",
106106
func = list(pkg = "dials", fun = "penalty"),
@@ -109,9 +109,9 @@ set_model_arg(
109109

110110
set_model_arg(
111111
mod = "multinom_reg",
112-
eng = "glmnet",
113-
val = "elastic_net_param",
114-
original = "alpha",
112+
eng = "spark",
113+
val = "mixture",
114+
original = "elastic_net_param",
115115
func = list(pkg = "dials", fun = "mixture"),
116116
submodels = FALSE
117117
)

R/nearest_neighbor_data.R

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@ set_model_engine("nearest_neighbor", "classification", "kknn")
1010
set_model_engine("nearest_neighbor", "regression", "kknn")
1111
set_dependency("nearest_neighbor", "kknn", "kknn")
1212

13-
set_model_arg(
14-
mod = "nearest_neighbor",
15-
eng = "kknn",
16-
val = "num_terms",
17-
original = "nprune",
18-
func = list(pkg = "dials", fun = "num_terms"),
19-
submodels = FALSE
20-
)
2113
set_model_arg(
2214
mod = "nearest_neighbor",
2315
eng = "kknn",
@@ -37,8 +29,8 @@ set_model_arg(
3729
set_model_arg(
3830
mod = "nearest_neighbor",
3931
eng = "kknn",
40-
val = "distance",
41-
original = "dist_power",
32+
val = "dist_power",
33+
original = "distance",
4234
func = list(pkg = "dials", fun = "distance"),
4335
submodels = FALSE
4436
)
@@ -49,7 +41,7 @@ set_fit(
4941
mode = "regression",
5042
value = list(
5143
interface = "formula",
52-
protect = c("formula", "data", "kmax"), # kmax is not allowed
44+
protect = c("formula", "data", "ks"),
5345
func = c(pkg = "kknn", fun = "train.kknn"),
5446
defaults = list()
5547
)
@@ -61,7 +53,7 @@ set_fit(
6153
mode = "classification",
6254
value = list(
6355
interface = "formula",
64-
protect = c("formula", "data", "kmax"), # kmax is not allowed
56+
protect = c("formula", "data", "ks"),
6557
func = c(pkg = "kknn", fun = "train.kknn"),
6658
defaults = list()
6759
)

R/nullmodel.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) {
160160
#' @export
161161
null_model <-
162162
function(mode = "classification") {
163+
null_model_modes <- unique(get_model_env()$null_model$mode)
163164
# Check for correct mode
164165
if (!(mode %in% null_model_modes))
165166
stop("`mode` should be one of: ",

R/predict_class.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ predict_class.model_fit <- function(object, new_data, ...) {
1313
stop("`predict.model_fit()` is for predicting factor outcomes.",
1414
call. = FALSE)
1515

16-
if (!any(names(object$spec$method) == "class"))
16+
if (!any(names(object$spec$method$pred) == "class"))
1717
stop("No class prediction module defined for this model.", call. = FALSE)
1818

1919
if (inherits(object$fit, "try-error")) {
@@ -24,17 +24,17 @@ predict_class.model_fit <- function(object, new_data, ...) {
2424
new_data <- prepare_data(object, new_data)
2525

2626
# preprocess data
27-
if (!is.null(object$spec$method$class$pre))
28-
new_data <- object$spec$method$class$pre(new_data, object)
27+
if (!is.null(object$spec$method$pred$class$pre))
28+
new_data <- object$spec$method$pred$class$pre(new_data, object)
2929

3030
# create prediction call
31-
pred_call <- make_pred_call(object$spec$method$class)
31+
pred_call <- make_pred_call(object$spec$method$pred$class)
3232

3333
res <- eval_tidy(pred_call)
3434

3535
# post-process the predictions
36-
if (!is.null(object$spec$method$class$post)) {
37-
res <- object$spec$method$class$post(res, object)
36+
if (!is.null(object$spec$method$pred$class$post)) {
37+
res <- object$spec$method$pred$class$post(res, object)
3838
}
3939

4040
# coerce levels to those in `object`

R/predict_classprob.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
1010
stop("`predict.model_fit()` is for predicting factor outcomes.",
1111
call. = FALSE)
1212

13-
if (!any(names(object$spec$method) == "classprob"))
13+
if (!any(names(object$spec$method$pred) == "prob"))
1414
stop("No class probability module defined for this model.", call. = FALSE)
1515

1616
if (inherits(object$fit, "try-error")) {
@@ -21,17 +21,17 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
2121
new_data <- prepare_data(object, new_data)
2222

2323
# preprocess data
24-
if (!is.null(object$spec$method$classprob$pre))
25-
new_data <- object$spec$method$classprob$pre(new_data, object)
24+
if (!is.null(object$spec$method$pred$prob$pre))
25+
new_data <- object$spec$method$pred$prob$pre(new_data, object)
2626

2727
# create prediction call
28-
pred_call <- make_pred_call(object$spec$method$classprob)
28+
pred_call <- make_pred_call(object$spec$method$pred$prob)
2929

3030
res <- eval_tidy(pred_call)
3131

3232
# post-process the predictions
33-
if (!is.null(object$spec$method$classprob$post)) {
34-
res <- object$spec$method$classprob$post(res, object)
33+
if (!is.null(object$spec$method$pred$prob$post)) {
34+
res <- object$spec$method$pred$prob$post(res, object)
3535
}
3636

3737
# check and sort names

0 commit comments

Comments
 (0)