Skip to content

Commit d7594f3

Browse files
authored
Merge pull request #1 from tidymodels/master
Merge from master
2 parents 670e75d + 0e83faf commit d7594f3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+483
-98
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.0.5.9000
2+
Version: 0.0.5.9001
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(
@@ -31,7 +31,7 @@ Imports:
3131
prettyunits,
3232
vctrs (>= 0.2.0)
3333
Roxygen: list(markdown = TRUE)
34-
RoxygenNote: 7.0.2.9000
34+
RoxygenNote: 7.1.0
3535
Suggests:
3636
testthat,
3737
knitr,

NAMESPACE

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@ S3method(predict,model_spec)
2525
S3method(predict,nullmodel)
2626
S3method(predict_class,"_lognet")
2727
S3method(predict_class,"_multnet")
28+
S3method(predict_class,model_fit)
2829
S3method(predict_classprob,"_lognet")
2930
S3method(predict_classprob,"_multnet")
31+
S3method(predict_classprob,model_fit)
32+
S3method(predict_confint,model_fit)
3033
S3method(predict_numeric,"_elnet")
34+
S3method(predict_numeric,model_fit)
35+
S3method(predict_quantile,model_fit)
3136
S3method(predict_raw,"_elnet")
3237
S3method(predict_raw,"_lognet")
3338
S3method(predict_raw,"_multnet")
39+
S3method(predict_raw,model_fit)
3440
S3method(print,boost_tree)
3541
S3method(print,control_parsnip)
3642
S3method(print,decision_tree)
@@ -92,7 +98,10 @@ export(boost_tree)
9298
export(check_empty_ellipse)
9399
export(check_final_param)
94100
export(control_parsnip)
101+
export(convert_args)
102+
export(convert_stan_interval)
95103
export(decision_tree)
104+
export(eval_args)
96105
export(fit)
97106
export(fit.model_spec)
98107
export(fit_control)
@@ -122,6 +131,13 @@ export(null_value)
122131
export(nullmodel)
123132
export(pred_value_template)
124133
export(predict.model_fit)
134+
export(predict_class.model_fit)
135+
export(predict_classprob.model_fit)
136+
export(predict_confint.model_fit)
137+
export(predict_numeric)
138+
export(predict_numeric.model_fit)
139+
export(predict_quantile.model_fit)
140+
export(predict_raw.model_fit)
125141
export(rand_forest)
126142
export(rpart_train)
127143
export(set_args)
@@ -144,6 +160,7 @@ export(svm_poly)
144160
export(svm_rbf)
145161
export(tidy.model_fit)
146162
export(translate)
163+
export(translate.default)
147164
export(update_dot_check)
148165
export(update_main_parameters)
149166
export(varying)

R/aaa.R

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ maybe_multivariate <- function(results, object) {
1313
results
1414
}
1515

16+
#' Convenience function for intervals
1617
#' @importFrom stats quantile
18+
#' @export
19+
#' @keywords internal
20+
#' @param x A fitted model object
21+
#' @param level Level of uncertainty for intervals
22+
#' @param lower Is `level` the lower level?
1723
convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
1824
alpha <- (1 - level) / 2
1925
if (!lower) {
@@ -24,6 +30,33 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
2430
res
2531
}
2632

33+
#' Make a table of arguments
34+
#' @param model_name A character string for the model
35+
#' @keywords internal
36+
#' @export
37+
convert_args <- function(model_name) {
38+
envir <- get_model_env()
39+
40+
args <-
41+
ls(envir) %>%
42+
tibble::tibble(name = .) %>%
43+
dplyr::filter(grepl("args", name)) %>%
44+
dplyr::mutate(model = sub("_args", "", name),
45+
args = purrr::map(name, ~envir[[.x]])) %>%
46+
tidyr::unnest(args) %>%
47+
dplyr::select(model:original)
48+
49+
convert_df <- args %>%
50+
dplyr::filter(grepl(model_name, model)) %>%
51+
dplyr::select(-model) %>%
52+
tidyr::pivot_wider(names_from = engine, values_from = original)
53+
54+
convert_df %>%
55+
knitr::kable(col.names = paste0("**", colnames(convert_df), "**"))
56+
57+
}
58+
59+
2760
# ------------------------------------------------------------------------------
2861
# nocov
2962

@@ -32,8 +65,8 @@ utils::globalVariables(
3265
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
3366
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
3467
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
35-
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
68+
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
3669
"sub_neighbors", ".pred_class")
37-
)
70+
)
3871

3972
# nocov end

R/aaa_multi_predict.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,18 @@ multi_predict_args.default <- function(object, ...) {
117117
#' @export
118118
#' @rdname has_multi_predict
119119
multi_predict_args.model_fit <- function(object, ...) {
120-
existing_mthds <- methods("multi_predict")
121-
cls <- class(object)
122-
tst <- paste0("multi_predict.", cls)
123-
.fn <- tst[tst %in% existing_mthds]
124-
if (length(.fn) == 0) {
125-
return(NA_character_)
120+
model_type <- class(object$spec)[1]
121+
arg_info <- get_from_env(paste0(model_type, "_args"))
122+
arg_info <- arg_info[arg_info$engine == object$spec$engine,]
123+
arg_info <- arg_info[arg_info$has_submodel,]
124+
125+
if (nrow(arg_info) == 0) {
126+
res <- NA_character_
127+
} else {
128+
res <- arg_info[["parsnip"]]
126129
}
127130

128-
.fn <- getFromNamespace(.fn, ns = "parsnip")
129-
omit <- c('object', 'new_data', 'type', '...')
130-
args <- names(formals(.fn))
131-
args[!(args %in% omit)]
131+
res
132132
}
133133

134134
#' @export

R/arguments.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ maybe_eval <- function(x) {
103103
y
104104
}
105105

106+
#' Evaluate parsnip model arguments
107+
#' @export
108+
#' @keywords internal
109+
#' @param spec A model specification
110+
#' @param ... Not used.
106111
eval_args <- function(spec, ...) {
107112
spec$args <- purrr::map(spec$args, maybe_eval)
108113
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)

R/boost_tree.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@
6161
#'
6262
#' @section Engine Details:
6363
#'
64+
#' The standardized parameter names in parsnip can be mapped to their original
65+
#' names in each engine:
66+
#'
67+
#' ```{r echo = FALSE}
68+
#' convert_args("boost_tree")
69+
#' ```
70+
#'
6471
#' Engines may have pre-set default arguments when executing the
6572
#' model fit call. For this type of model, the template of the
6673
#' fit calls are:

R/decision_tree.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@
4646
#'
4747
#' @section Engine Details:
4848
#'
49+
#' The standardized parameter names in parsnip can be mapped to their original
50+
#' names in each engine:
51+
#'
52+
#' ```{r echo = FALSE}
53+
#' convert_args("decision_tree")
54+
#' ```
55+
#'
4956
#' Engines may have pre-set default arguments when executing the
5057
#' model fit call. For this type of
5158
#' model, the template of the fit calls are::

R/linear_reg.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@
4444
#'
4545
#' @section Engine Details:
4646
#'
47+
#' The standardized parameter names in parsnip can be mapped to their original
48+
#' names in each engine:
49+
#'
50+
#' ```{r echo = FALSE}
51+
#' convert_args("linear_reg")
52+
#' ```
53+
#'
4754
#' Engines may have pre-set default arguments when executing the
4855
#' model fit call. For this type of
4956
#' model, the template of the fit calls are:
@@ -60,6 +67,8 @@
6067
#'
6168
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")}
6269
#'
70+
#' (note that the `refresh` default prevents logging of the estimation process. Change this value in `set_engine()` will show the logs)
71+
#'
6372
#' \pkg{spark}
6473
#'
6574
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
@@ -98,7 +107,7 @@
98107
#' separately saved to disk. In a new session, the object can be
99108
#' reloaded and reattached to the `parsnip` object.
100109
#'
101-
#' @seealso [[fit()], [set_engine()]
110+
#' @seealso [fit()], [set_engine()]
102111
#' @examples
103112
#' linear_reg()
104113
#' # Parameters can be represented by a placeholder:

R/linear_reg_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ set_fit(
179179
interface = "formula",
180180
protect = c("formula", "data", "weights"),
181181
func = c(pkg = "rstanarm", fun = "stan_glm"),
182-
defaults = list(family = expr(stats::gaussian))
182+
defaults = list(family = expr(stats::gaussian), refresh = 0)
183183
)
184184
)
185185

@@ -293,7 +293,7 @@ set_model_arg(
293293
parsnip = "penalty",
294294
original = "reg_param",
295295
func = list(pkg = "dials", fun = "penalty"),
296-
has_submodel = TRUE
296+
has_submodel = FALSE
297297
)
298298

299299
set_model_arg(

R/logistic_reg.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
#'
4343
#' @section Engine Details:
4444
#'
45+
#' The standardized parameter names in parsnip can be mapped to their original
46+
#' names in each engine:
47+
#'
48+
#' ```{r echo = FALSE}
49+
#' convert_args("logistic_reg")
50+
#' ```
51+
#'
4552
#' Engines may have pre-set default arguments when executing the
4653
#' model fit call. For this type of
4754
#' model, the template of the fit calls are:
@@ -58,6 +65,8 @@
5865
#'
5966
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")}
6067
#'
68+
#' (note that the `refresh` default prevents logging of the estimation process. Change this value in `set_engine()` will show the logs)
69+
#'
6170
#' \pkg{spark}
6271
#'
6372
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")}
@@ -97,7 +106,7 @@
97106
#' separately saved to disk. In a new session, the object can be
98107
#' reloaded and reattached to the `parsnip` object.
99108
#'
100-
#' @seealso [[fit()]
109+
#' @seealso [fit()]
101110
#' @examples
102111
#' logistic_reg()
103112
#' # Parameters can be represented by a placeholder:

R/logistic_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ set_fit(
359359
interface = "formula",
360360
protect = c("formula", "data", "weights"),
361361
func = c(pkg = "rstanarm", fun = "stan_glm"),
362-
defaults = list(family = expr(stats::binomial))
362+
defaults = list(family = expr(stats::binomial), refresh = 0)
363363
)
364364
)
365365

R/mars.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@
3838
#'
3939
#' @section Engine Details:
4040
#'
41+
#' The standardized parameter names in parsnip can be mapped to their original
42+
#' names in each engine:
43+
#'
44+
#' ```{r echo = FALSE}
45+
#' convert_args("mars")
46+
#' ```
47+
#'
4148
#' Engines may have pre-set default arguments when executing the
4249
#' model fit call. For this type of
4350
#' model, the template of the fit calls are:
@@ -55,7 +62,7 @@
5562
#' attached.
5663
#'
5764
#' @importFrom purrr map_lgl
58-
#' @seealso [[fit()]
65+
#' @seealso [fit()]
5966
#' @examples
6067
#' mars(mode = "regression", num_terms = 5)
6168
#' @export

R/mlp.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@
5656
#'
5757
#' @section Engine Details:
5858
#'
59+
#' The standardized parameter names in parsnip can be mapped to their original
60+
#' names in each engine:
61+
#'
62+
#' ```{r echo = FALSE}
63+
#' convert_args("mlp")
64+
#' ```
65+
#'
5966
#' Engines may have pre-set default arguments when executing the
6067
#' model fit call. For this type of
6168
#' model, the template of the fit calls are:
@@ -77,7 +84,7 @@
7784
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")}
7885
#'
7986
#' @importFrom purrr map_lgl
80-
#' @seealso [[fit()]
87+
#' @seealso [fit()]
8188
#' @examples
8289
#' mlp(mode = "classification", penalty = 0.01)
8390
#' # Parameters can be represented by a placeholder:

R/multinom_reg.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
#'
4242
#' @section Engine Details:
4343
#'
44+
#' The standardized parameter names in parsnip can be mapped to their original
45+
#' names in each engine:
46+
#'
47+
#' ```{r echo = FALSE}
48+
#' convert_args("multinom_reg")
49+
#' ```
50+
#'
4451
#' Engines may have pre-set default arguments when executing the
4552
#' model fit call. For this type of
4653
#' model, the template of the fit calls are:
@@ -84,7 +91,7 @@
8491
#' separately saved to disk. In a new session, the object can be
8592
#' reloaded and reattached to the `parsnip` object.
8693
#'
87-
#' @seealso [[fit()]
94+
#' @seealso [fit()]
8895
#' @examples
8996
#' multinom_reg()
9097
#' # Parameters can be represented by a placeholder:

R/nearest_neighbor.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
#'
4949
#' @section Engine Details:
5050
#'
51+
#' The standardized parameter names in parsnip can be mapped to their original
52+
#' names in each engine:
53+
#'
54+
#' ```{r echo = FALSE}
55+
#' convert_args("nearest_neighbor")
56+
#' ```
57+
#'
5158
#' Engines may have pre-set default arguments when executing the
5259
#' model fit call. For this type of
5360
#' model, the template of the fit calls are:
@@ -63,7 +70,7 @@
6370
#' on new data. This also means that a single value of that function's
6471
#' `kernel` argument (a.k.a `weight_func` here) can be supplied
6572
#'
66-
#' @seealso [[fit()]
73+
#' @seealso [fit()]
6774
#'
6875
#' @examples
6976
#' nearest_neighbor(neighbors = 11)

0 commit comments

Comments
 (0)