Skip to content

Commit 4130745

Browse files
committed
changed some argument names to be a little more specific
1 parent fc6ccbe commit 4130745

16 files changed

+233
-167
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ export(multinom_reg)
110110
export(nearest_neighbor)
111111
export(null_model)
112112
export(nullmodel)
113+
export(pred_types)
113114
export(predict.model_fit)
114115
export(rand_forest)
115116
export(rpart_train)

R/aaa_models.R

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ parsnip$modes <- c("regression", "classification", "unknown")
3030

3131
# ------------------------------------------------------------------------------
3232

33+
#' @rdname check_mod_val
34+
#' @keywords internal
35+
#' @export
3336
pred_types <-
3437
c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile")
3538

@@ -62,6 +65,30 @@ get_model_env <- function() {
6265
#' @param mode A single character string for the model mode (e.g. "regression").
6366
#' @param eng A single character string for the model engine.
6467
#' @param arg A single character string for the model argument name.
68+
#' @param has_submodel A single logical for whether the argument
69+
#' can make predictions on mutiple submodels at once.
70+
#' @param func A named character vector that describes how to call
71+
#' a function. `func` should have elements `pkg` and `fun`. The
72+
#' former is optional but is recommended and the latter is
73+
#' required. For example, `c(pkg = "stats", fun = "lm")` would be
74+
#' used to invoke the usual linear regression function. In some
75+
#' cases, it is helpful to use `c(fun = "predict")` when using a
76+
#' package's `predict` method.
77+
#' @param fit_obj A list with elements `interface`, `protect`,
78+
#' `func` and `defaults`. See the package vignette "Making a
79+
#' `parsnip` model from scratch".
80+
#' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`.
81+
#' See the package vignette "Making a `parsnip` model from scratch".
82+
#' @param type A single character value for the type of prediction. Possible
83+
#' values are:
84+
#' \Sexpr[results=rd]{paste0("'", parsnip::pred_types, "'", collapse = ", ")}.
85+
#' @param pkg An options character string for a package name.
86+
#' @param parsnip A single character string for the "harmonized" argument name
87+
#' that `parsnip` exposes.
88+
#' @param original A single character string for the argument name that
89+
#' underlying model function uses.
90+
#' @param value A list that conforms to the `fit_obj` or `pred_obj` description
91+
#' above, depending on context.
6592
#' @keywords internal
6693
#' @export
6794
check_mod_val <- function(model, new = FALSE, existence = FALSE) {
@@ -122,8 +149,8 @@ check_arg_val <- function(arg) {
122149
#' @rdname check_mod_val
123150
#' @keywords internal
124151
#' @export
125-
check_submodels_val <- function(x) {
126-
if (!is.logical(x) || length(x) != 1) {
152+
check_submodels_val <- function(has_submodel) {
153+
if (!is.logical(has_submodel) || length(has_submodel) != 1) {
127154
stop("The `submodels` argument should be a single logical.", call. = FALSE)
128155
}
129156
invisible(NULL)
@@ -169,31 +196,31 @@ check_func_val <- function(func) {
169196
#' @rdname check_mod_val
170197
#' @keywords internal
171198
#' @export
172-
check_fit_info <- function(x) {
173-
if (is.null(x)) {
199+
check_fit_info <- function(fit_obj) {
200+
if (is.null(fit_obj)) {
174201
stop("The `fit` module cannot be NULL.", call. = FALSE)
175202
}
176203
exp_nms <- c("defaults", "func", "interface", "protect")
177-
if (!isTRUE(all.equal(sort(names(x)), exp_nms))) {
204+
if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
178205
stop("The `fit` module should have elements: ",
179206
paste0("`", exp_nms, "`", collapse = ", "),
180207
call. = FALSE)
181208
}
182209

183210
exp_interf <- c("data.frame", "formula", "matrix")
184-
if (length(x$interface) > 1) {
211+
if (length(fit_obj$interface) > 1) {
185212
stop("The `interface` element should have a single value of : ",
186213
paste0("`", exp_interf, "`", collapse = ", "),
187214
call. = FALSE)
188215
}
189-
if (!any(x$interface == exp_interf)) {
216+
if (!any(fit_obj$interface == exp_interf)) {
190217
stop("The `interface` element should have a value of : ",
191218
paste0("`", exp_interf, "`", collapse = ", "),
192219
call. = FALSE)
193220
}
194-
check_func_val(x$func)
221+
check_func_val(fit_obj$func)
195222

196-
if (!is.list(x$defaults)) {
223+
if (!is.list(fit_obj$defaults)) {
197224
stop("The `defaults` element should be a list: ", call. = FALSE)
198225
}
199226

@@ -203,32 +230,32 @@ check_fit_info <- function(x) {
203230
#' @rdname check_mod_val
204231
#' @keywords internal
205232
#' @export
206-
check_pred_info <- function(x, type) {
233+
check_pred_info <- function(pred_obj, type) {
207234
if (all(type != pred_types)) {
208235
stop("The prediction type should be one of: ",
209236
paste0("'", pred_types, "'", collapse = ", "),
210237
call. = FALSE)
211238
}
212239

213240
exp_nms <- c("args", "func", "post", "pre")
214-
if (!isTRUE(all.equal(sort(names(x)), exp_nms))) {
241+
if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) {
215242
stop("The `predict` module should have elements: ",
216243
paste0("`", exp_nms, "`", collapse = ", "),
217244
call. = FALSE)
218245
}
219246

220-
if (!is.null(x$pre) & !is.function(x$pre)) {
247+
if (!is.null(pred_obj$pre) & !is.function(pred_obj$pre)) {
221248
stop("The `pre` module should be null or a function: ",
222249
call. = FALSE)
223250
}
224-
if (!is.null(x$post) & !is.function(x$post)) {
251+
if (!is.null(pred_obj$post) & !is.function(pred_obj$post)) {
225252
stop("The `post` module should be null or a function: ",
226253
call. = FALSE)
227254
}
228255

229-
check_func_val(x$func)
256+
check_func_val(pred_obj$func)
230257

231-
if (!is.list(x$args)) {
258+
if (!is.list(pred_obj$args)) {
232259
stop("The `args` element should be a list. ", call. = FALSE)
233260
}
234261

@@ -238,8 +265,8 @@ check_pred_info <- function(x, type) {
238265
#' @rdname check_mod_val
239266
#' @keywords internal
240267
#' @export
241-
check_pkg_val <- function(x) {
242-
if (is_missing(x) || length(x) != 1 || !is.character(x))
268+
check_pkg_val <- function(pkg) {
269+
if (is_missing(pkg) || length(pkg) != 1 || !is.character(pkg))
243270
stop("Please supply a single character vale for the package name",
244271
call. = FALSE)
245272
invisible(NULL)
@@ -333,23 +360,23 @@ set_model_engine <- function(model, mode, eng) {
333360
#' @rdname get_model_env
334361
#' @keywords internal
335362
#' @export
336-
set_model_arg <- function(model, eng, val, original, func, submodels) {
363+
set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
337364
check_mod_val(model, existence = TRUE)
338-
check_arg_val(val)
365+
check_arg_val(parsnip)
339366
check_arg_val(original)
340367
check_func_val(func)
341-
check_submodels_val(submodels)
368+
check_submodels_val(has_submodel)
342369

343370
current <- get_model_env()
344371
old_args <- current[[paste0(model, "_args")]]
345372

346373
new_arg <-
347374
dplyr::tibble(
348375
engine = eng,
349-
parsnip = val,
376+
parsnip = parsnip,
350377
original = original,
351378
func = list(func),
352-
submodels = submodels
379+
has_submodel = has_submodel
353380
)
354381

355382
# TODO cant currently use `distinct()` on a list column.
@@ -359,7 +386,7 @@ set_model_arg <- function(model, eng, val, original, func, submodels) {
359386
stop("An error occured when adding the new argument.", call. = FALSE)
360387
}
361388

362-
updated <- dplyr::distinct(updated, engine, parsnip, original, submodels)
389+
updated <- dplyr::distinct(updated, engine, parsnip, original, has_submodel)
363390

364391
current[[paste0(model, "_args")]] <- updated
365392

R/boost_tree_data.R

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,58 +12,58 @@ set_dependency("boost_tree", "xgboost", "xgboost")
1212
set_model_arg(
1313
model = "boost_tree",
1414
eng = "xgboost",
15-
val = "tree_depth",
15+
parsnip = "tree_depth",
1616
original = "max_depth",
1717
func = list(pkg = "dials", fun = "tree_depth"),
18-
submodels = FALSE
18+
has_submodel = FALSE
1919
)
2020
set_model_arg(
2121
model = "boost_tree",
2222
eng = "xgboost",
23-
val = "trees",
23+
parsnip = "trees",
2424
original = "nrounds",
2525
func = list(pkg = "dials", fun = "trees"),
26-
submodels = TRUE
26+
has_submodel = TRUE
2727
)
2828
set_model_arg(
2929
model = "boost_tree",
3030
eng = "xgboost",
31-
val = "learn_rate",
31+
parsnip = "learn_rate",
3232
original = "eta",
3333
func = list(pkg = "dials", fun = "learn_rate"),
34-
submodels = FALSE
34+
has_submodel = FALSE
3535
)
3636
set_model_arg(
3737
model = "boost_tree",
3838
eng = "xgboost",
39-
val = "mtry",
39+
parsnip = "mtry",
4040
original = "colsample_bytree",
4141
func = list(pkg = "dials", fun = "mtry"),
42-
submodels = FALSE
42+
has_submodel = FALSE
4343
)
4444
set_model_arg(
4545
model = "boost_tree",
4646
eng = "xgboost",
47-
val = "min_n",
47+
parsnip = "min_n",
4848
original = "min_child_weight",
4949
func = list(pkg = "dials", fun = "min_n"),
50-
submodels = FALSE
50+
has_submodel = FALSE
5151
)
5252
set_model_arg(
5353
model = "boost_tree",
5454
eng = "xgboost",
55-
val = "loss_reduction",
55+
parsnip = "loss_reduction",
5656
original = "gamma",
5757
func = list(pkg = "dials", fun = "loss_reduction"),
58-
submodels = FALSE
58+
has_submodel = FALSE
5959
)
6060
set_model_arg(
6161
model = "boost_tree",
6262
eng = "xgboost",
63-
val = "sample_size",
63+
parsnip = "sample_size",
6464
original = "subsample",
6565
func = list(pkg = "dials", fun = "sample_size"),
66-
submodels = FALSE
66+
has_submodel = FALSE
6767
)
6868

6969
set_fit(
@@ -178,26 +178,26 @@ set_dependency("boost_tree", "C5.0", "C50")
178178
set_model_arg(
179179
model = "boost_tree",
180180
eng = "C5.0",
181-
val = "trees",
181+
parsnip = "trees",
182182
original = "trials",
183183
func = list(pkg = "dials", fun = "trees"),
184-
submodels = TRUE
184+
has_submodel = TRUE
185185
)
186186
set_model_arg(
187187
model = "boost_tree",
188188
eng = "C5.0",
189-
val = "min_n",
189+
parsnip = "min_n",
190190
original = "minCases",
191191
func = list(pkg = "dials", fun = "min_n"),
192-
submodels = FALSE
192+
has_submodel = FALSE
193193
)
194194
set_model_arg(
195195
model = "boost_tree",
196196
eng = "C5.0",
197-
val = "sample_size",
197+
parsnip = "sample_size",
198198
original = "sample",
199199
func = list(pkg = "dials", fun = "sample_size"),
200-
submodels = FALSE
200+
has_submodel = FALSE
201201
)
202202

203203
set_fit(
@@ -268,58 +268,58 @@ set_dependency("boost_tree", "spark", "sparklyr")
268268
set_model_arg(
269269
model = "boost_tree",
270270
eng = "spark",
271-
val = "tree_depth",
271+
parsnip = "tree_depth",
272272
original = "max_depth",
273273
func = list(pkg = "dials", fun = "tree_depth"),
274-
submodels = FALSE
274+
has_submodel = FALSE
275275
)
276276
set_model_arg(
277277
model = "boost_tree",
278278
eng = "spark",
279-
val = "trees",
279+
parsnip = "trees",
280280
original = "max_iter",
281281
func = list(pkg = "dials", fun = "trees"),
282-
submodels = TRUE
282+
has_submodel = TRUE
283283
)
284284
set_model_arg(
285285
model = "boost_tree",
286286
eng = "spark",
287-
val = "learn_rate",
287+
parsnip = "learn_rate",
288288
original = "step_size",
289289
func = list(pkg = "dials", fun = "learn_rate"),
290-
submodels = FALSE
290+
has_submodel = FALSE
291291
)
292292
set_model_arg(
293293
model = "boost_tree",
294294
eng = "spark",
295-
val = "mtry",
295+
parsnip = "mtry",
296296
original = "feature_subset_strategy",
297297
func = list(pkg = "dials", fun = "mtry"),
298-
submodels = FALSE
298+
has_submodel = FALSE
299299
)
300300
set_model_arg(
301301
model = "boost_tree",
302302
eng = "spark",
303-
val = "min_n",
303+
parsnip = "min_n",
304304
original = "min_instances_per_node",
305305
func = list(pkg = "dials", fun = "min_n"),
306-
submodels = FALSE
306+
has_submodel = FALSE
307307
)
308308
set_model_arg(
309309
model = "boost_tree",
310310
eng = "spark",
311-
val = "min_info_gain",
311+
parsnip = "min_info_gain",
312312
original = "gamma",
313313
func = list(pkg = "dials", fun = "loss_reduction"),
314-
submodels = FALSE
314+
has_submodel = FALSE
315315
)
316316
set_model_arg(
317317
model = "boost_tree",
318318
eng = "spark",
319-
val = "sample_size",
319+
parsnip = "sample_size",
320320
original = "subsampling_rate",
321321
func = list(pkg = "dials", fun = "sample_size"),
322-
submodels = FALSE
322+
has_submodel = FALSE
323323
)
324324

325325
set_fit(

0 commit comments

Comments
 (0)