Skip to content

Commit 39bf71e

Browse files
committed
template for prediction objects
1 parent 0c90d72 commit 39bf71e

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ export(nearest_neighbor)
104104
export(null_model)
105105
export(nullmodel)
106106
export(pred_types)
107+
export(pred_value_template)
107108
export(predict.model_fit)
108109
export(rand_forest)
109110
export(rpart_train)

R/aaa_models.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ pred_types <-
7878
#' @param value A list that conforms to the `fit_obj` or `pred_obj` description
7979
#' above, depending on context.
8080
#' @param items A character string of objects in the model environment.
81+
#' @param pre,post Optional functions for pre- and post-processing of prediction
82+
#' results.
83+
#' @param ... Optional arguments that should be passed into the `args` slot for
84+
#' prediction objects
8185
#' @keywords internal
8286
#' @details These functions are available for users to add their
8387
#' own models or engines (in package or otherwise) so that they can
@@ -692,3 +696,15 @@ get_from_env <- function(items) {
692696
rlang::env_get(mod_env, items)
693697
}
694698

699+
# ------------------------------------------------------------------------------
700+
701+
#' @rdname get_model_env
702+
#' @keywords internal
703+
#' @export
704+
pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
705+
if (rlang::is_missing(func)) {
706+
stop("Please supply a value to `func`. See `?set_pred`.", call. = FALSE)
707+
}
708+
list(pre = pre, post = post, func = func, args = list(...))
709+
}
710+

man/get_model_env.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vignettes/articles/Scratch.Rmd

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,19 @@ A similar call can be used to define the class probability module (if they can
210210

211211
As an example of the `post` function, the data frame created by `mda:::predict.mda` will be converted to a tibble. The arguments are `x` (the raw results coming from the predict method) and `object` (the `parsnip` model fit object). The latter has a sub-object called `lvl` which is a character string of the outcome's factor levels (if any).
212212

213-
We register the probability module:
213+
We register the probability module. There is a template function that makes this slightly easier to format the objects:
214214

215215
```{r mda-prob}
216216
prob_info <-
217-
list(
218-
pre = NULL,
217+
pred_value_template(
219218
post = function(x, object) {
220-
tibble::as_tibble(x)
221-
},
219+
tibble::as_tibble(x)
220+
},
222221
func = c(fun = "predict"),
223-
args =
224-
list(
225-
object = quote(object$fit),
226-
newdata = quote(new_data),
227-
type = "posterior"
228-
)
222+
# Now everything else is put into the `args` slot
223+
object = quote(object$fit),
224+
newdata = quote(new_data),
225+
type = "posterior"
229226
)
230227
231228
set_pred(

0 commit comments

Comments
 (0)