@@ -30,6 +30,9 @@ parsnip$modes <- c("regression", "classification", "unknown")
30
30
31
31
# ------------------------------------------------------------------------------
32
32
33
+ # ' @rdname check_mod_val
34
+ # ' @keywords internal
35
+ # ' @export
33
36
pred_types <-
34
37
c(" raw" , " numeric" , " class" , " link" , " prob" , " conf_int" , " pred_int" , " quantile" )
35
38
@@ -62,6 +65,30 @@ get_model_env <- function() {
62
65
# ' @param mode A single character string for the model mode (e.g. "regression").
63
66
# ' @param eng A single character string for the model engine.
64
67
# ' @param arg A single character string for the model argument name.
68
+ # ' @param has_submodel A single logical for whether the argument
69
+ # ' can make predictions on mutiple submodels at once.
70
+ # ' @param func A named character vector that describes how to call
71
+ # ' a function. `func` should have elements `pkg` and `fun`. The
72
+ # ' former is optional but is recommended and the latter is
73
+ # ' required. For example, `c(pkg = "stats", fun = "lm")` would be
74
+ # ' used to invoke the usual linear regression function. In some
75
+ # ' cases, it is helpful to use `c(fun = "predict")` when using a
76
+ # ' package's `predict` method.
77
+ # ' @param fit_obj A list with elements `interface`, `protect`,
78
+ # ' `func` and `defaults`. See the package vignette "Making a
79
+ # ' `parsnip` model from scratch".
80
+ # ' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`.
81
+ # ' See the package vignette "Making a `parsnip` model from scratch".
82
+ # ' @param type A single character value for the type of prediction. Possible
83
+ # ' values are:
84
+ # ' \Sexpr[results=rd]{paste0("'", parsnip::pred_types, "'", collapse = ", ")}.
85
+ # ' @param pkg An options character string for a package name.
86
+ # ' @param parsnip A single character string for the "harmonized" argument name
87
+ # ' that `parsnip` exposes.
88
+ # ' @param original A single character string for the argument name that
89
+ # ' underlying model function uses.
90
+ # ' @param value A list that conforms to the `fit_obj` or `pred_obj` description
91
+ # ' above, depending on context.
65
92
# ' @keywords internal
66
93
# ' @export
67
94
check_mod_val <- function (model , new = FALSE , existence = FALSE ) {
@@ -122,8 +149,8 @@ check_arg_val <- function(arg) {
122
149
# ' @rdname check_mod_val
123
150
# ' @keywords internal
124
151
# ' @export
125
- check_submodels_val <- function (x ) {
126
- if (! is.logical(x ) || length(x ) != 1 ) {
152
+ check_submodels_val <- function (has_submodel ) {
153
+ if (! is.logical(has_submodel ) || length(has_submodel ) != 1 ) {
127
154
stop(" The `submodels` argument should be a single logical." , call. = FALSE )
128
155
}
129
156
invisible (NULL )
@@ -169,31 +196,31 @@ check_func_val <- function(func) {
169
196
# ' @rdname check_mod_val
170
197
# ' @keywords internal
171
198
# ' @export
172
- check_fit_info <- function (x ) {
173
- if (is.null(x )) {
199
+ check_fit_info <- function (fit_obj ) {
200
+ if (is.null(fit_obj )) {
174
201
stop(" The `fit` module cannot be NULL." , call. = FALSE )
175
202
}
176
203
exp_nms <- c(" defaults" , " func" , " interface" , " protect" )
177
- if (! isTRUE(all.equal(sort(names(x )), exp_nms ))) {
204
+ if (! isTRUE(all.equal(sort(names(fit_obj )), exp_nms ))) {
178
205
stop(" The `fit` module should have elements: " ,
179
206
paste0(" `" , exp_nms , " `" , collapse = " , " ),
180
207
call. = FALSE )
181
208
}
182
209
183
210
exp_interf <- c(" data.frame" , " formula" , " matrix" )
184
- if (length(x $ interface ) > 1 ) {
211
+ if (length(fit_obj $ interface ) > 1 ) {
185
212
stop(" The `interface` element should have a single value of : " ,
186
213
paste0(" `" , exp_interf , " `" , collapse = " , " ),
187
214
call. = FALSE )
188
215
}
189
- if (! any(x $ interface == exp_interf )) {
216
+ if (! any(fit_obj $ interface == exp_interf )) {
190
217
stop(" The `interface` element should have a value of : " ,
191
218
paste0(" `" , exp_interf , " `" , collapse = " , " ),
192
219
call. = FALSE )
193
220
}
194
- check_func_val(x $ func )
221
+ check_func_val(fit_obj $ func )
195
222
196
- if (! is.list(x $ defaults )) {
223
+ if (! is.list(fit_obj $ defaults )) {
197
224
stop(" The `defaults` element should be a list: " , call. = FALSE )
198
225
}
199
226
@@ -203,32 +230,32 @@ check_fit_info <- function(x) {
203
230
# ' @rdname check_mod_val
204
231
# ' @keywords internal
205
232
# ' @export
206
- check_pred_info <- function (x , type ) {
233
+ check_pred_info <- function (pred_obj , type ) {
207
234
if (all(type != pred_types )) {
208
235
stop(" The prediction type should be one of: " ,
209
236
paste0(" '" , pred_types , " '" , collapse = " , " ),
210
237
call. = FALSE )
211
238
}
212
239
213
240
exp_nms <- c(" args" , " func" , " post" , " pre" )
214
- if (! isTRUE(all.equal(sort(names(x )), exp_nms ))) {
241
+ if (! isTRUE(all.equal(sort(names(pred_obj )), exp_nms ))) {
215
242
stop(" The `predict` module should have elements: " ,
216
243
paste0(" `" , exp_nms , " `" , collapse = " , " ),
217
244
call. = FALSE )
218
245
}
219
246
220
- if (! is.null(x $ pre ) & ! is.function(x $ pre )) {
247
+ if (! is.null(pred_obj $ pre ) & ! is.function(pred_obj $ pre )) {
221
248
stop(" The `pre` module should be null or a function: " ,
222
249
call. = FALSE )
223
250
}
224
- if (! is.null(x $ post ) & ! is.function(x $ post )) {
251
+ if (! is.null(pred_obj $ post ) & ! is.function(pred_obj $ post )) {
225
252
stop(" The `post` module should be null or a function: " ,
226
253
call. = FALSE )
227
254
}
228
255
229
- check_func_val(x $ func )
256
+ check_func_val(pred_obj $ func )
230
257
231
- if (! is.list(x $ args )) {
258
+ if (! is.list(pred_obj $ args )) {
232
259
stop(" The `args` element should be a list. " , call. = FALSE )
233
260
}
234
261
@@ -238,8 +265,8 @@ check_pred_info <- function(x, type) {
238
265
# ' @rdname check_mod_val
239
266
# ' @keywords internal
240
267
# ' @export
241
- check_pkg_val <- function (x ) {
242
- if (is_missing(x ) || length(x ) != 1 || ! is.character(x ))
268
+ check_pkg_val <- function (pkg ) {
269
+ if (is_missing(pkg ) || length(pkg ) != 1 || ! is.character(pkg ))
243
270
stop(" Please supply a single character vale for the package name" ,
244
271
call. = FALSE )
245
272
invisible (NULL )
@@ -333,23 +360,23 @@ set_model_engine <- function(model, mode, eng) {
333
360
# ' @rdname get_model_env
334
361
# ' @keywords internal
335
362
# ' @export
336
- set_model_arg <- function (model , eng , val , original , func , submodels ) {
363
+ set_model_arg <- function (model , eng , parsnip , original , func , has_submodel ) {
337
364
check_mod_val(model , existence = TRUE )
338
- check_arg_val(val )
365
+ check_arg_val(parsnip )
339
366
check_arg_val(original )
340
367
check_func_val(func )
341
- check_submodels_val(submodels )
368
+ check_submodels_val(has_submodel )
342
369
343
370
current <- get_model_env()
344
371
old_args <- current [[paste0(model , " _args" )]]
345
372
346
373
new_arg <-
347
374
dplyr :: tibble(
348
375
engine = eng ,
349
- parsnip = val ,
376
+ parsnip = parsnip ,
350
377
original = original ,
351
378
func = list (func ),
352
- submodels = submodels
379
+ has_submodel = has_submodel
353
380
)
354
381
355
382
# TODO cant currently use `distinct()` on a list column.
@@ -359,7 +386,7 @@ set_model_arg <- function(model, eng, val, original, func, submodels) {
359
386
stop(" An error occured when adding the new argument." , call. = FALSE )
360
387
}
361
388
362
- updated <- dplyr :: distinct(updated , engine , parsnip , original , submodels )
389
+ updated <- dplyr :: distinct(updated , engine , parsnip , original , has_submodel )
363
390
364
391
current [[paste0(model , " _args" )]] <- updated
365
392
0 commit comments