Skip to content

Commit cc3f892

Browse files
committed
Merge branch 'master' into try-catch-augment
2 parents b2b054d + f364b81 commit cc3f892

36 files changed

+326
-152
lines changed

NEWS.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
# parsnip (development version)
22

3-
* `generics::required_pkgs()` was extended for `parsnip` objects.
4-
5-
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
3+
## Model Specification Changes
64

75
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424) and the `kernlab` engine (#438), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429). These models can use sparse matrices via `fit_xy()` (#447) and have a `tidy` method (#474).
86

7+
* For models with `glmnet` engines:
8+
9+
- A single value is required for `penalty` (either a single numeric value or a value of `tune()`) (#481).
10+
- A special argument called `path_values` can be used to set the `lambda` path as a specific set of numbers (independent of the value of `penalty`). A pure ridge regression models (i.e., `mixture = 1`) will generate incorrect values if the path does not include zero. See issue #431 for discussion (#486).
11+
12+
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
13+
914
* New model specification `survival_reg()` for the new mode `"censored regression"` (#444). `surv_reg()` is now soft-deprecated (#448).
1015

1116
* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451).
1217

18+
## Other Changes
19+
1320
* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).
1421

1522
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467).
1623

1724
* Re-organized model documentation for `update` methods (#479).
1825

26+
27+
28+
* `generics::required_pkgs()` was extended for `parsnip` objects.
29+
30+
1931

2032
# parsnip 0.1.5
2133

R/aaa_models.R

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,27 @@ check_model_doesnt_exist <- function(model) {
127127
}
128128

129129
check_mode_val <- function(mode) {
130-
if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode))
130+
if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) {
131131
rlang::abort("Please supply a character string for a mode (e.g. `'regression'`).")
132+
}
132133
invisible(NULL)
133134
}
134135

135136
# check if class and mode are compatible
136137
check_spec_mode_val <- function(cls, mode) {
137138
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
138-
if (!(mode %in% spec_modes))
139-
rlang::abort(
140-
glue::glue(
141-
"`mode` should be one of: ",
142-
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
143-
)
139+
compatible_modes <-
140+
glue::glue(
141+
"`mode` should be one of: ",
142+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
144143
)
144+
145+
if (is.null(mode)) {
146+
rlang::abort(compatible_modes)
147+
} else if (!(mode %in% spec_modes)) {
148+
rlang::abort(compatible_modes)
149+
}
150+
145151
invisible(NULL)
146152
}
147153

@@ -280,12 +286,27 @@ check_pred_info <- function(pred_obj, type) {
280286
invisible(NULL)
281287
}
282288

289+
check_spec_pred_type <- function(object, type) {
290+
possible_preds <- names(object$spec$method$pred)
291+
if (!any(possible_preds == type)) {
292+
rlang::abort(c(
293+
glue::glue("No {type} prediction method available for this model."),
294+
glue::glue("Value for `type` should be one of: ",
295+
glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", "))
296+
))
297+
}
298+
invisible(NULL)
299+
}
300+
301+
283302
check_pkg_val <- function(pkg) {
284-
if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg))
303+
if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) {
285304
rlang::abort("Please supply a single character value for the package name.")
305+
}
286306
invisible(NULL)
287307
}
288308

309+
289310
check_interface_val <- function(x) {
290311
exp_interf <- c("data.frame", "formula", "matrix")
291312
if (length(x) != 1 || !(x %in% exp_interf)) {

R/arguments.R

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,10 @@ set_args <- function(object, ...) {
7676
#' @rdname set_args
7777
#' @export
7878
set_mode <- function(object, mode) {
79-
if (is.null(mode))
80-
return(object)
81-
mode <- mode[1]
82-
if (!(any(all_modes == mode))) {
83-
rlang::abort(
84-
glue::glue(
85-
"`mode` should be one of ",
86-
glue::glue_collapse(glue::glue("'{all_modes}'"), sep = ", ")
87-
)
88-
)
79+
if (rlang::is_missing(mode)) {
80+
mode <- NULL
8981
}
82+
mode <- mode[1]
9083
check_spec_mode_val(class(object)[1], mode)
9184
object$mode <- mode
9285
object

R/engines.R

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ check_engine <- function(object) {
1515
if (is.null(object$engine)) {
1616
object$engine <- avail_eng[1]
1717
rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`"))
18+
} else {
19+
if (!is.character(object$engine) | length(object$engine) != 1) {
20+
rlang::abort("`engine` should be a single character value.")
21+
}
1822
}
1923
if (!(object$engine %in% avail_eng)) {
2024
rlang::abort(
@@ -91,18 +95,20 @@ set_engine <- function(object, engine, ...) {
9195
if (!inherits(object, "model_spec")) {
9296
rlang::abort("`object` should have class 'model_spec'.")
9397
}
94-
if (!is.character(engine) | length(engine) != 1)
95-
rlang::abort("`engine` should be a single character value.")
96-
if (engine == "liquidSVM") {
98+
99+
if (rlang::is_missing(engine)) {
100+
engine <- NULL
101+
}
102+
object$engine <- engine
103+
object <- check_engine(object)
104+
105+
if (object$engine == "liquidSVM") {
97106
lifecycle::deprecate_soft(
98107
"0.1.6",
99108
"set_engine(engine = 'cannot be liquidSVM')",
100109
details = "The liquidSVM package is no longer available on CRAN.")
101110
}
102111

103-
object$engine <- engine
104-
object <- check_engine(object)
105-
106112
new_model_spec(
107113
cls = class(object)[1],
108114
args = object$args,

R/linear_reg.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,23 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
107107
x <- translate.default(x, engine, ...)
108108

109109
if (engine == "glmnet") {
110-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
111-
x$method$fit$args$lambda <- NULL
110+
check_glmnet_penalty(x)
111+
if (any(names(x$eng_args) == "path_values")) {
112+
# Since we decouple the parsnip `penalty` argument from being the same
113+
# as the glmnet `lambda` value, `path_values` allows users to set the
114+
# path differently from the default that glmnet uses. See
115+
# https://github.com/tidymodels/parsnip/issues/431
116+
x$method$fit$args$lambda <- x$eng_args$path_values
117+
x$eng_args$path_values <- NULL
118+
x$method$fit$args$path_values <- NULL
119+
} else {
120+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
121+
x$method$fit$args$lambda <- NULL
122+
}
112123
# Since the `fit` information is gone for the penalty, we need to have an
113124
# evaluated value for the parameter.
114125
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
115-
check_glmnet_penalty(x)
116126
}
117-
118127
x
119128
}
120129

R/logistic_reg.R

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,23 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
108108
arg_vals <- x$method$fit$args
109109
arg_names <- names(arg_vals)
110110

111-
112111
if (engine == "glmnet") {
113-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
114-
arg_vals$lambda <- NULL
112+
check_glmnet_penalty(x)
113+
if (any(names(x$eng_args) == "path_values")) {
114+
# Since we decouple the parsnip `penalty` argument from being the same
115+
# as the glmnet `lambda` value, `path_values` allows users to set the
116+
# path differently from the default that glmnet uses. See
117+
# https://github.com/tidymodels/parsnip/issues/431
118+
x$method$fit$args$lambda <- x$eng_args$path_values
119+
x$eng_args$path_values <- NULL
120+
x$method$fit$args$path_values <- NULL
121+
} else {
122+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
123+
x$method$fit$args$lambda <- NULL
124+
}
115125
# Since the `fit` information is gone for the penalty, we need to have an
116126
# evaluated value for the parameter.
117127
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
118-
check_glmnet_penalty(x)
119128
}
120129

121130
if (engine == "LiblineaR") {
@@ -134,11 +143,8 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
134143
rlang::abort("For the LiblineaR engine, mixture must be 0 or 1.")
135144
}
136145
}
137-
146+
x$method$fit$args <- arg_vals
138147
}
139-
140-
x$method$fit$args <- arg_vals
141-
142148
x
143149
}
144150

R/misc.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,12 @@ stan_conf_int <- function(object, newdata) {
324324
}
325325

326326
check_glmnet_penalty <- function(x) {
327-
if (length(x$args$penalty) != 1) {
327+
pen <- rlang::eval_tidy(x$args$penalty)
328+
329+
if (length(pen) != 1) {
328330
rlang::abort(c(
329331
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
330-
glue::glue("There are {length(x$args$penalty)} values for `penalty`."),
332+
glue::glue("There are {length(pen)} values for `penalty`."),
331333
"To try multiple values for total regularization, use the tune package.",
332334
"To predict multiple penalties, use `multi_predict()`"
333335
))

R/predict_class.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ predict_class.model_fit <- function(object, new_data, ...) {
1212
if (object$spec$mode != "classification")
1313
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1414

15-
if (!any(names(object$spec$method$pred) == "class"))
16-
rlang::abort("No class prediction module defined for this model.")
15+
check_spec_pred_type(object, "class")
1716

1817
if (inherits(object$fit, "try-error")) {
1918
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_classprob.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
99
if (object$spec$mode != "classification")
1010
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1111

12-
if (!any(names(object$spec$method$pred) == "prob"))
13-
rlang::abort("No class probability module defined for this model.")
12+
check_spec_pred_type(object, "prob")
13+
1414

1515
if (inherits(object$fit, "try-error")) {
1616
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_hazard.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_hazard.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$hazard))
11-
rlang::abort("No hazard prediction method defined for this engine.")
10+
check_spec_pred_type(object, "hazard")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_interval.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
#' @export
1111
predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
1212

13-
if (is.null(object$spec$method$pred$conf_int))
14-
rlang::abort("No confidence interval method defined for this engine.")
13+
check_spec_pred_type(object, "conf_int")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")
@@ -58,8 +57,7 @@ predict_confint <- function(object, ...)
5857
# @export
5958
predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
6059

61-
if (is.null(object$spec$method$pred$pred_int))
62-
rlang::abort("No prediction interval method defined for this engine.")
60+
check_spec_pred_type(object, "pred_int")
6361

6462
if (inherits(object$fit, "try-error")) {
6563
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_linear_pred.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#' @export
77
predict_linear_pred.model_fit <- function(object, new_data, ...) {
88

9-
if (!any(names(object$spec$method$pred) == "linear_pred"))
10-
rlang::abort("No prediction module defined for this model.")
9+
check_spec_pred_type(object, "linear_pred")
1110

1211
if (inherits(object$fit, "try-error")) {
1312
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_numeric.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "numeric"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "numeric")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_quantile.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
predict_quantile.model_fit <-
1010
function(object, new_data, quantile = (1:9)/10, ...) {
1111

12-
if (is.null(object$spec$method$pred$quantile))
13-
rlang::abort("No quantile prediction method defined for this engine.")
12+
check_spec_pred_type(object, "quantile")
1413

1514
if (inherits(object$fit, "try-error")) {
1615
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_raw.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) {
1313
c(object$spec$method$pred$raw$args, opts)
1414
}
1515

16-
if (!any(names(object$spec$method$pred) == "raw"))
17-
rlang::abort("No raw prediction module defined for this model.")
16+
check_spec_pred_type(object, "raw")
1817

1918
if (inherits(object$fit, "try-error")) {
2019
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_survival.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_survival.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$survival))
11-
rlang::abort("No survival prediction method defined for this engine.")
10+
check_spec_pred_type(object, "survival")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_time.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_time.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "time"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "time")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

R/svm_linear_data.R

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,27 +123,6 @@ set_pred(
123123
)
124124
)
125125
)
126-
set_pred(
127-
model = "svm_linear",
128-
eng = "LiblineaR",
129-
mode = "classification",
130-
type = "prob",
131-
value = list(
132-
pre = function(x, object) {
133-
rlang::abort(
134-
paste0("The LiblineaR engine does not support class probabilities ",
135-
"for any `svm` models.")
136-
)
137-
},
138-
post = NULL,
139-
func = c(fun = "predict"),
140-
args =
141-
list(
142-
object = quote(object$fit),
143-
newx = expr(as.matrix(new_data))
144-
)
145-
)
146-
)
147126
set_pred(
148127
model = "svm_linear",
149128
eng = "LiblineaR",

0 commit comments

Comments
 (0)