Skip to content

Commit 963ba95

Browse files
authored
Merge pull request #396 from EmilHvitfeldt/survnip-integration
adding new prediction types for Survnip
2 parents 154c1ab + c0c0191 commit 963ba95

12 files changed

+419
-57
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.1.5.9000
2+
Version: 0.1.5.9001
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(

NAMESPACE

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@ S3method(predict_classprob,"_lognet")
3232
S3method(predict_classprob,"_multnet")
3333
S3method(predict_classprob,model_fit)
3434
S3method(predict_confint,model_fit)
35+
S3method(predict_hazard,model_fit)
36+
S3method(predict_linear_pred,model_fit)
3537
S3method(predict_numeric,"_elnet")
3638
S3method(predict_numeric,model_fit)
3739
S3method(predict_quantile,model_fit)
3840
S3method(predict_raw,"_elnet")
3941
S3method(predict_raw,"_lognet")
4042
S3method(predict_raw,"_multnet")
4143
S3method(predict_raw,model_fit)
44+
S3method(predict_survival,model_fit)
45+
S3method(predict_time,model_fit)
4246
S3method(print,boost_tree)
4347
S3method(print,control_parsnip)
4448
S3method(print,decision_tree)
@@ -156,11 +160,17 @@ export(predict.model_fit)
156160
export(predict_class.model_fit)
157161
export(predict_classprob.model_fit)
158162
export(predict_confint.model_fit)
163+
export(predict_hazard.model_fit)
164+
export(predict_linear_pred)
165+
export(predict_linear_pred.model_fit)
159166
export(predict_numeric)
160167
export(predict_numeric.model_fit)
161168
export(predict_quantile.model_fit)
162169
export(predict_raw)
163170
export(predict_raw.model_fit)
171+
export(predict_survival.model_fit)
172+
export(predict_time)
173+
export(predict_time.model_fit)
164174
export(prepare_data)
165175
export(rand_forest)
166176
export(repair_call)

R/aaa_models.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
3131
# ------------------------------------------------------------------------------
3232

3333
pred_types <-
34-
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
34+
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
35+
"time", "survival", "linear_pred", "hazard")
3536

3637
# ------------------------------------------------------------------------------
3738

R/misc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ check_empty_ellipse <- function (...) {
2323
terms
2424
}
2525

26-
all_modes <- c("classification", "regression")
26+
all_modes <- c("classification", "regression", "censored regression")
2727

2828

2929
deparserizer <- function(x, limit = options()$width - 10) {

R/predict.R

Lines changed: 162 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
#' @param object An object of class `model_fit`
88
#' @param new_data A rectangular data object, such as a data frame.
99
#' @param type A single character value or `NULL`. Possible values
10-
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
11-
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
12-
#' based on the model's mode.
10+
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
11+
#' "hazard", "survival", or "raw". When `NULL`, `predict()` will choose an
12+
#' appropriate value based on the model's mode.
1313
#' @param opts A list of optional arguments to the underlying
1414
#' predict function that will be used when `type = "raw"`. The
1515
#' list should not include options for the model object or the
@@ -28,20 +28,32 @@
2828
#' and "pred_int". Default value is `FALSE`.
2929
#' \item `quantile`: the quantile(s) for quantile regression
3030
#' (not implemented yet)
31-
#' \item `time`: the time(s) for hazard probability estimates
32-
#' (not implemented yet)
31+
#' \item `.time`: the time(s) for hazard and survival probability estimates.
3332
#' }
3433
#' @details If "type" is not supplied to `predict()`, then a choice
35-
#' is made (`type = "numeric"` for regression models and
36-
#' `type = "class"` for classification).
34+
#' is made:
35+
#'
36+
#' * `type = "numeric"` for regression models,
37+
#' * `type = "class"` for classification, and
38+
#' * `type = "time"` for censored regression.
3739
#'
3840
#' `predict()` is designed to provide a tidy result (see "Value"
3941
#' section below) in a tibble output format.
4042
#'
43+
#' ## Interval predictions
44+
#'
4145
#' When using `type = "conf_int"` and `type = "pred_int"`, the options
4246
#' `level` and `std_error` can be used. The latter is a logical for an
4347
#' extra column of standard error values (if available).
4448
#'
49+
#' ## Censored regression predictions
50+
#'
51+
#' For censored regression, a numeric vector for `.time` is required when
52+
#' survival or hazard probabilities are requested. Also, when
53+
#' `type = "linear_pred"`, censored regression models will be formatted such
54+
#' that the linear predictor _increases_ with time. This may have the opposite
55+
#' sign as what the underlying model's `predict()` method produces.
56+
#'
4557
#' @return With the exception of `type = "raw"`, the results of
4658
#' `predict.model_fit()` will be a tibble as many rows in the output
4759
#' as there are rows in `new_data` and the column names will be
@@ -66,6 +78,15 @@
6678
#' Using `type = "raw"` with `predict.model_fit()` will return
6779
#' the unadulterated results of the prediction function.
6880
#'
81+
#' For censored regression:
82+
#'
83+
#' * `type = "time"` produces a column `.pred_time`.
84+
#' * `type = "hazard"` results in a column `.pred_hazard`.
85+
#' * `type = "survival"` results in a column `.pred_survival`.
86+
#'
87+
#' For the last two types, the results are a nested tibble with an overall
88+
#' column called `.pred` with sub-tibbles with the above format.
89+
#'
6990
#' In the case of Spark-based models, since table columns cannot
7091
#' contain dots, the same convention is used except 1) no dots
7192
#' appear in names and 2) vectors are never returned but
@@ -108,10 +129,6 @@
108129
#' @export predict.model_fit
109130
#' @export
110131
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
111-
the_dots <- enquos(...)
112-
if (any(names(the_dots) == "newdata"))
113-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
114-
115132
if (inherits(object$fit, "try-error")) {
116133
rlang::warn("Model fit failed; cannot make predictions.")
117134
return(NULL)
@@ -120,53 +137,54 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
120137
check_installs(object$spec)
121138
load_libs(object$spec, quiet = TRUE)
122139

123-
other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
124-
is_pred_arg <- names(the_dots) %in% other_args
125-
if (any(!is_pred_arg)) {
126-
bad_args <- names(the_dots)[!is_pred_arg]
127-
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
128-
rlang::abort(
129-
glue::glue(
130-
"The ellipses are not used to pass args to the model function's ",
131-
"predict function. These arguments cannot be used: {bad_args}",
132-
)
133-
)
134-
}
135-
136140
type <- check_pred_type(object, type)
137-
if (type != "raw" && length(opts) > 0)
141+
if (type != "raw" && length(opts) > 0) {
138142
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
143+
}
144+
check_pred_type_dots(type, ...)
145+
139146
res <- switch(
140147
type,
141-
numeric = predict_numeric(object = object, new_data = new_data, ...),
142-
class = predict_class(object = object, new_data = new_data, ...),
143-
prob = predict_classprob(object = object, new_data = new_data, ...),
144-
conf_int = predict_confint(object = object, new_data = new_data, ...),
145-
pred_int = predict_predint(object = object, new_data = new_data, ...),
146-
quantile = predict_quantile(object = object, new_data = new_data, ...),
147-
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
148+
numeric = predict_numeric(object = object, new_data = new_data, ...),
149+
class = predict_class(object = object, new_data = new_data, ...),
150+
prob = predict_classprob(object = object, new_data = new_data, ...),
151+
conf_int = predict_confint(object = object, new_data = new_data, ...),
152+
pred_int = predict_predint(object = object, new_data = new_data, ...),
153+
quantile = predict_quantile(object = object, new_data = new_data, ...),
154+
time = predict_time(object = object, new_data = new_data, ...),
155+
survival = predict_survival(object = object, new_data = new_data, ...),
156+
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
157+
hazard = predict_hazard(object = object, new_data = new_data, ...),
158+
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
148159
rlang::abort(glue::glue("I don't know about type = '{type}'"))
149160
)
150161
if (!inherits(res, "tbl_spark")) {
151162
res <- switch(
152163
type,
153-
numeric = format_num(res),
154-
class = format_class(res),
155-
prob = format_classprobs(res),
164+
numeric = format_num(res),
165+
class = format_class(res),
166+
prob = format_classprobs(res),
167+
time = format_time(res),
168+
survival = format_survival(res),
169+
hazard = format_hazard(res),
170+
linear_pred = format_linear_pred(res),
156171
res
157172
)
158173
}
159174
res
160175
}
161176

177+
surv_types <- c("time", "survival", "hazard")
178+
162179
#' @importFrom glue glue_collapse
163-
check_pred_type <- function(object, type) {
180+
check_pred_type <- function(object, type, ...) {
164181
if (is.null(type)) {
165182
type <-
166183
switch(object$spec$mode,
167184
regression = "numeric",
168185
classification = "class",
169-
rlang::abort("`type` should be 'regression' or 'classification'."))
186+
"censored regression" = "time",
187+
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
170188
}
171189
if (!(type %in% pred_types))
172190
rlang::abort(
@@ -181,6 +199,10 @@ check_pred_type <- function(object, type) {
181199
rlang::abort("For class predictions, the object should be a classification model.")
182200
if (type == "prob" & object$spec$mode != "classification")
183201
rlang::abort("For probability predictions, the object should be a classification model.")
202+
if (type %in% surv_types & object$spec$mode != "censored regression")
203+
rlang::abort("For event time predictions, the object should be a censored regression.")
204+
205+
# TODO check for ... options when not the correct type
184206
type
185207
}
186208

@@ -216,6 +238,61 @@ format_classprobs <- function(x) {
216238
x
217239
}
218240

241+
format_time <- function(x) {
242+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
243+
x <- as_tibble(x, .name_repair = "minimal")
244+
if (!any(grepl("^\\.time", names(x)))) {
245+
names(x) <- paste0(".time_", names(x))
246+
}
247+
} else {
248+
x <- tibble(.pred_time = unname(x))
249+
}
250+
251+
x
252+
}
253+
254+
format_survival <- function(x) {
255+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
256+
x <- as_tibble(x, .name_repair = "minimal")
257+
if (!any(grepl("^\\.time", names(x)))) {
258+
names(x) <- paste0(".time_", names(x))
259+
}
260+
} else {
261+
x <- tibble(.pred_survival = unname(x))
262+
}
263+
264+
x
265+
}
266+
267+
format_linear_pred <- function(x) {
268+
if (inherits(x, "tbl_spark"))
269+
return(x)
270+
271+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
272+
x <- as_tibble(x, .name_repair = "minimal")
273+
if (!any(grepl("^\\.time", names(x)))) {
274+
names(x) <- paste0(".time_", names(x))
275+
}
276+
} else {
277+
x <- tibble(.pred_linear_pred = unname(x))
278+
}
279+
280+
x
281+
}
282+
283+
format_hazard <- function(x) {
284+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
285+
x <- as_tibble(x, .name_repair = "minimal")
286+
if (!any(grepl("^\\.time", names(x)))) {
287+
names(x) <- paste0(".time_", names(x))
288+
}
289+
} else {
290+
x <- tibble(.pred_hazard = unname(x))
291+
}
292+
293+
x
294+
}
295+
219296
make_pred_call <- function(x) {
220297
if ("pkg" %in% names(x$func))
221298
cl <-
@@ -226,6 +303,54 @@ make_pred_call <- function(x) {
226303
cl
227304
}
228305

306+
check_pred_type_dots <- function(type, ...) {
307+
the_dots <- list(...)
308+
nms <- names(the_dots)
309+
310+
# ----------------------------------------------------------------------------
311+
312+
if (any(names(the_dots) == "newdata")) {
313+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
314+
}
315+
316+
# ----------------------------------------------------------------------------
317+
318+
other_args <- c("level", "std_error", "quantile", ".time")
319+
is_pred_arg <- names(the_dots) %in% other_args
320+
if (any(!is_pred_arg)) {
321+
bad_args <- names(the_dots)[!is_pred_arg]
322+
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
323+
rlang::abort(
324+
glue::glue(
325+
"The ellipses are not used to pass args to the model function's ",
326+
"predict function. These arguments cannot be used: {bad_args}",
327+
)
328+
)
329+
}
330+
331+
# ----------------------------------------------------------------------------
332+
# places where .time should not be given
333+
if (any(nms == ".time") & !type %in% c("survival", "hazard")) {
334+
rlang::abort(
335+
paste(
336+
".time should only be passed to `predict()` when 'type' is one of:",
337+
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
338+
)
339+
)
340+
}
341+
# when .time should be passed
342+
if (!any(nms == ".time") & type %in% c("survival", "hazard")) {
343+
rlang::abort(
344+
paste(
345+
"When using 'type' values of 'survival' or 'hazard' are given,",
346+
"a numeric vector '.time' should also be given."
347+
)
348+
)
349+
}
350+
invisible(TRUE)
351+
}
352+
353+
229354
#' Prepare data based on parsnip encoding information
230355
#' @param object A parsnip model object
231356
#' @param new_data A data frame

R/predict_hazard.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @inheritParams predict.model_fit
4+
#' @method predict_hazard model_fit
5+
#' @export predict_hazard.model_fit
6+
#' @export
7+
predict_hazard.model_fit <-
8+
function(object, new_data, .time, ...) {
9+
10+
if (is.null(object$spec$method$pred$hazard))
11+
rlang::abort("No hazard prediction method defined for this engine.")
12+
13+
if (inherits(object$fit, "try-error")) {
14+
rlang::warn("Model fit failed; cannot make predictions.")
15+
return(NULL)
16+
}
17+
18+
new_data <- prepare_data(object, new_data)
19+
20+
# preprocess data
21+
if (!is.null(object$spec$method$pred$hazard$pre))
22+
new_data <- object$spec$method$pred$hazard$pre(new_data, object)
23+
24+
# Pass some extra arguments to be used in post-processor
25+
object$spec$method$pred$hazard$args$.time <- .time
26+
pred_call <- make_pred_call(object$spec$method$pred$hazard)
27+
28+
res <- eval_tidy(pred_call)
29+
30+
# post-process the predictions
31+
if(!is.null(object$spec$method$pred$hazard$post)) {
32+
res <- object$spec$method$pred$hazard$post(res, object)
33+
}
34+
35+
res
36+
}
37+
38+
# @export
39+
# @keywords internal
40+
# @rdname other_predict
41+
# @inheritParams predict.model_fit
42+
predict_hazard <- function (object, ...)
43+
UseMethod("predict_hazard")

0 commit comments

Comments
 (0)