Skip to content

Commit bd3953e

Browse files
committed
updated checks and documentation for censored regression models
1 parent 435f07d commit bd3953e

File tree

5 files changed

+129
-44
lines changed

5 files changed

+129
-44
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Imports:
3232
prettyunits,
3333
vctrs (>= 0.2.0)
3434
Roxygen: list(markdown = TRUE)
35-
RoxygenNote: 7.1.1.9000
35+
RoxygenNote: 7.1.1.9001
3636
Suggests:
3737
testthat,
3838
knitr,

R/predict.R

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
#' @param object An object of class `model_fit`
88
#' @param new_data A rectangular data object, such as a data frame.
99
#' @param type A single character value or `NULL`. Possible values
10-
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
11-
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
12-
#' based on the model's mode.
10+
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
11+
#' "hazard", "survival", or "raw". When `NULL`, `predict()` will choose an
12+
#' appropriate value based on the model's mode.
1313
#' @param opts A list of optional arguments to the underlying
1414
#' predict function that will be used when `type = "raw"`. The
1515
#' list should not include options for the model object or the
@@ -28,20 +28,32 @@
2828
#' and "pred_int". Default value is `FALSE`.
2929
#' \item `quantile`: the quantile(s) for quantile regression
3030
#' (not implemented yet)
31-
#' \item `time`: the time(s) for hazard probability estimates
32-
#' (not implemented yet)
31+
#' \item `.time`: the time(s) for hazard and survival probability estimates.
3332
#' }
3433
#' @details If "type" is not supplied to `predict()`, then a choice
35-
#' is made (`type = "numeric"` for regression models and
36-
#' `type = "class"` for classification).
34+
#' is made:
35+
#'
36+
#' * `type = "numeric"` for regression models,
37+
#' * `type = "class"` for classification, and
38+
#' * `type = "time"` for censored regression.
3739
#'
3840
#' `predict()` is designed to provide a tidy result (see "Value"
3941
#' section below) in a tibble output format.
4042
#'
43+
#' ## Interval predictions
44+
#'
4145
#' When using `type = "conf_int"` and `type = "pred_int"`, the options
4246
#' `level` and `std_error` can be used. The latter is a logical for an
4347
#' extra column of standard error values (if available).
4448
#'
49+
#' ## Censored regression predictions
50+
#'
51+
#' For censored regression, a numeric vector for `.time` is required when
52+
#' survival or hazard probabilities are requested. Also, when
53+
#' `type = "linear_pred"`, censored regression models will be formatted such
54+
#' that the linear predictor _increases_ with time. This may have the opposite
55+
#' sign as what the underlying model's `predict()` method produces.
56+
#'
4557
#' @return With the exception of `type = "raw"`, the results of
4658
#' `predict.model_fit()` will be a tibble with as many rows in the output
4759
#' as there are rows in `new_data` and the column names will be
@@ -66,6 +78,15 @@
6678
#' Using `type = "raw"` with `predict.model_fit()` will return
6779
#' the unadulterated results of the prediction function.
6880
#'
81+
#' For censored regression:
82+
#'
83+
#' * `type = "time"` produces a column `.pred_time`.
84+
#' * `type = "hazard"` results in a column `.pred_hazard`.
85+
#' * `type = "survival"` results in a column `.pred_survival`.
86+
#'
87+
#' For the last two types, the result is a nested tibble: an overall column
88+
#' called `.pred` holds sub-tibbles with the above format.
89+
#'
6990
#' In the case of Spark-based models, since table columns cannot
7091
#' contain dots, the same convention is used except 1) no dots
7192
#' appear in names and 2) vectors are never returned but
@@ -108,10 +129,6 @@
108129
#' @export predict.model_fit
109130
#' @export
110131
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
111-
the_dots <- enquos(...)
112-
if (any(names(the_dots) == "newdata"))
113-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
114-
115132
if (inherits(object$fit, "try-error")) {
116133
rlang::warn("Model fit failed; cannot make predictions.")
117134
return(NULL)
@@ -120,22 +137,12 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
120137
check_installs(object$spec)
121138
load_libs(object$spec, quiet = TRUE)
122139

123-
other_args <- c("level", "std_error", "quantile", ".time")
124-
is_pred_arg <- names(the_dots) %in% other_args
125-
if (any(!is_pred_arg)) {
126-
bad_args <- names(the_dots)[!is_pred_arg]
127-
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
128-
rlang::abort(
129-
glue::glue(
130-
"The ellipses are not used to pass args to the model function's ",
131-
"predict function. These arguments cannot be used: {bad_args}",
132-
)
133-
)
134-
}
135-
136140
type <- check_pred_type(object, type)
137-
if (type != "raw" && length(opts) > 0)
141+
if (type != "raw" && length(opts) > 0) {
138142
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
143+
}
144+
check_pred_type_dots(type, ...)
145+
139146
res <- switch(
140147
type,
141148
numeric = predict_numeric(object = object, new_data = new_data, ...),
@@ -167,15 +174,17 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
167174
res
168175
}
169176

177+
# Prediction types that are only valid for censored regression models;
# check_pred_type() aborts when one of these is requested for any other mode.
surv_types <- c("time", "survival", "hazard")
178+
170179
#' @importFrom glue glue_collapse
171-
check_pred_type <- function(object, type) {
180+
check_pred_type <- function(object, type, ...) {
172181
if (is.null(type)) {
173182
type <-
174183
switch(object$spec$mode,
175184
regression = "numeric",
176185
classification = "class",
177186
"censored regression" = "time",
178-
rlang::abort("`type` should be 'regression' or 'classification'."))
187+
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
179188
}
180189
if (!(type %in% pred_types))
181190
rlang::abort(
@@ -190,6 +199,10 @@ check_pred_type <- function(object, type) {
190199
rlang::abort("For class predictions, the object should be a classification model.")
191200
if (type == "prob" & object$spec$mode != "classification")
192201
rlang::abort("For probability predictions, the object should be a classification model.")
202+
if (type %in% surv_types & object$spec$mode != "censored regression")
203+
rlang::abort("For event time predictions, the object should be a censored regression.")
204+
205+
# TODO check for ... options when not the correct type
193206
type
194207
}
195208

@@ -290,6 +303,54 @@ make_pred_call <- function(x) {
290303
cl
291304
}
292305

306+
# Validate the arguments supplied through `...` in a call to `predict()`.
#
# Arguments:
#   type  A single character string giving the (already resolved) prediction
#         type.
#   ...   The arguments that were passed to `predict()` through the ellipses.
#
# Aborts when:
#   * `newdata` is given (a likely misspelling of `new_data`),
#   * any argument name is not one of `level`, `std_error`, `quantile`, or
#     `.time`,
#   * `.time` is given but `type` is neither "survival" nor "hazard",
#   * `type` is "survival" or "hazard" but `.time` is missing.
#
# Returns `TRUE` invisibly when every check passes.
check_pred_type_dots <- function(type, ...) {
  the_dots <- list(...)
  # `nms` is NULL when nothing (or only unnamed values) is passed; every
  # comparison below then yields a zero-length logical and `any()` is FALSE.
  nms <- names(the_dots)

  # ----------------------------------------------------------------------------

  if (any(nms == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  # ----------------------------------------------------------------------------

  other_args <- c("level", "std_error", "quantile", ".time")
  is_pred_arg <- nms %in% other_args
  if (any(!is_pred_arg)) {
    bad_args <- nms[!is_pred_arg]
    bad_args <- paste0("`", bad_args, "`", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The ellipses are not used to pass args to the model function's ",
        "predict function. These arguments cannot be used: {bad_args}"
      )
    )
  }

  # ----------------------------------------------------------------------------
  # `.time` and the survival/hazard types must be used together.

  time_types <- c("survival", "hazard")
  has_time <- any(nms == ".time")
  needs_time <- type %in% time_types

  # places where .time should not be given
  # Both operands are scalars, so use the short-circuit `&&` in `if`.
  if (has_time && !needs_time) {
    rlang::abort(
      paste(
        ".time should only be passed to `predict()` when 'type' is one of:",
        paste0("'", time_types, "'", collapse = ", ")
      )
    )
  }
  # when .time should be passed
  if (!has_time && needs_time) {
    rlang::abort(
      paste(
        "When 'type' is 'survival' or 'hazard',",
        "a numeric vector '.time' should also be given."
      )
    )
  }
  invisible(TRUE)
}
352+
353+
293354
#' Prepare data based on parsnip encoding information
294355
#' @param object A parsnip model object
295356
#' @param new_data A data frame

man/other_predict.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict.model_fit.Rd

Lines changed: 31 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rand_forest.Rd

Lines changed: 8 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)