7
7
# ' @param object An object of class `model_fit`
8
8
# ' @param new_data A rectangular data object, such as a data frame.
9
9
# ' @param type A single character value or `NULL`. Possible values
10
- # ' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
11
- # ' or "raw". When `NULL`, `predict()` will choose an appropriate value
12
- # ' based on the model's mode.
10
+ # ' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
11
+ # ' "hazard", "survival", "linear_pred", or "raw". When `NULL`, `predict()` will
12
+ # ' choose an appropriate value based on the model's mode.
13
13
# ' @param opts A list of optional arguments to the underlying
14
14
# ' predict function that will be used when `type = "raw"`. The
15
15
# ' list should not include options for the model object or the
28
28
# ' and "pred_int". Default value is `FALSE`.
29
29
# ' \item `quantile`: the quantile(s) for quantile regression
30
30
# ' (not implemented yet)
31
- # ' \item `time`: the time(s) for hazard probability estimates
32
- # ' (not implemented yet)
31
+ # ' \item `.time`: the time(s) for hazard and survival probability estimates.
33
32
# ' }
34
33
# ' @details If "type" is not supplied to `predict()`, then a choice
35
- # ' is made (`type = "numeric"` for regression models and
36
- # ' `type = "class"` for classification).
34
+ # ' is made:
35
+ # '
36
+ # ' * `type = "numeric"` for regression models,
37
+ # ' * `type = "class"` for classification, and
38
+ # ' * `type = "time"` for censored regression.
37
39
# '
38
40
# ' `predict()` is designed to provide a tidy result (see "Value"
39
41
# ' section below) in a tibble output format.
40
42
# '
43
+ # ' ## Interval predictions
44
+ # '
41
45
# ' When using `type = "conf_int"` and `type = "pred_int"`, the options
42
46
# ' `level` and `std_error` can be used. The latter is a logical for an
43
47
# ' extra column of standard error values (if available).
44
48
# '
49
+ # ' ## Censored regression predictions
50
+ # '
51
+ # ' For censored regression, a numeric vector for `.time` is required when
52
+ # ' survival or hazard probabilities are requested. Also, when
53
+ # ' `type = "linear_pred"`, predictions from censored regression models are
54
+ # ' formatted such that the linear predictor _increases_ with time. This may be
55
+ # ' the opposite sign from what the underlying model's `predict()` method produces.
56
+ # '
45
57
# ' @return With the exception of `type = "raw"`, the results of
46
58
# ' `predict.model_fit()` will be a tibble with as many rows in the output
47
59
# ' as there are rows in `new_data` and the column names will be
66
78
# ' Using `type = "raw"` with `predict.model_fit()` will return
67
79
# ' the unadulterated results of the prediction function.
68
80
# '
81
+ # ' For censored regression:
82
+ # '
83
+ # ' * `type = "time"` produces a column `.pred_time`.
84
+ # ' * `type = "hazard"` results in a column `.pred_hazard`.
85
+ # ' * `type = "survival"` results in a column `.pred_survival`.
86
+ # '
87
+ # ' For the last two types, the result is a nested tibble: an overall column
88
+ # ' called `.pred` contains sub-tibbles with the above format.
89
+ # '
69
90
# ' In the case of Spark-based models, since table columns cannot
70
91
# ' contain dots, the same convention is used except 1) no dots
71
92
# ' appear in names and 2) vectors are never returned but
108
129
# ' @export predict.model_fit
109
130
# ' @export
110
131
predict.model_fit <- function (object , new_data , type = NULL , opts = list (), ... ) {
111
- the_dots <- enquos(... )
112
- if (any(names(the_dots ) == " newdata" ))
113
- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
114
-
115
132
if (inherits(object $ fit , " try-error" )) {
116
133
rlang :: warn(" Model fit failed; cannot make predictions." )
117
134
return (NULL )
@@ -120,53 +137,54 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
120
137
check_installs(object $ spec )
121
138
load_libs(object $ spec , quiet = TRUE )
122
139
123
- other_args <- c(" level" , " std_error" , " quantile" ) # "time" for survival probs later
124
- is_pred_arg <- names(the_dots ) %in% other_args
125
- if (any(! is_pred_arg )) {
126
- bad_args <- names(the_dots )[! is_pred_arg ]
127
- bad_args <- paste0(" `" , bad_args , " `" , collapse = " , " )
128
- rlang :: abort(
129
- glue :: glue(
130
- " The ellipses are not used to pass args to the model function's " ,
131
- " predict function. These arguments cannot be used: {bad_args}" ,
132
- )
133
- )
134
- }
135
-
136
140
type <- check_pred_type(object , type )
137
- if (type != " raw" && length(opts ) > 0 )
141
+ if (type != " raw" && length(opts ) > 0 ) {
138
142
rlang :: warn(" `opts` is only used with `type = 'raw'` and was ignored." )
143
+ }
144
+ check_pred_type_dots(type , ... )
145
+
139
146
res <- switch (
140
147
type ,
141
- numeric = predict_numeric(object = object , new_data = new_data , ... ),
142
- class = predict_class(object = object , new_data = new_data , ... ),
143
- prob = predict_classprob(object = object , new_data = new_data , ... ),
144
- conf_int = predict_confint(object = object , new_data = new_data , ... ),
145
- pred_int = predict_predint(object = object , new_data = new_data , ... ),
146
- quantile = predict_quantile(object = object , new_data = new_data , ... ),
147
- raw = predict_raw(object = object , new_data = new_data , opts = opts , ... ),
148
+ numeric = predict_numeric(object = object , new_data = new_data , ... ),
149
+ class = predict_class(object = object , new_data = new_data , ... ),
150
+ prob = predict_classprob(object = object , new_data = new_data , ... ),
151
+ conf_int = predict_confint(object = object , new_data = new_data , ... ),
152
+ pred_int = predict_predint(object = object , new_data = new_data , ... ),
153
+ quantile = predict_quantile(object = object , new_data = new_data , ... ),
154
+ time = predict_time(object = object , new_data = new_data , ... ),
155
+ survival = predict_survival(object = object , new_data = new_data , ... ),
156
+ linear_pred = predict_linear_pred(object = object , new_data = new_data , ... ),
157
+ hazard = predict_hazard(object = object , new_data = new_data , ... ),
158
+ raw = predict_raw(object = object , new_data = new_data , opts = opts , ... ),
148
159
rlang :: abort(glue :: glue(" I don't know about type = '{type}'" ))
149
160
)
150
161
if (! inherits(res , " tbl_spark" )) {
151
162
res <- switch (
152
163
type ,
153
- numeric = format_num(res ),
154
- class = format_class(res ),
155
- prob = format_classprobs(res ),
164
+ numeric = format_num(res ),
165
+ class = format_class(res ),
166
+ prob = format_classprobs(res ),
167
+ time = format_time(res ),
168
+ survival = format_survival(res ),
169
+ hazard = format_hazard(res ),
170
+ linear_pred = format_linear_pred(res ),
156
171
res
157
172
)
158
173
}
159
174
res
160
175
}
161
176
177
+ # Prediction types that `check_pred_type()` only allows for censored regression models.
+ surv_types <- c(" time" , " survival" , " hazard" )
178
+
162
179
# ' @importFrom glue glue_collapse
163
- check_pred_type <- function (object , type ) {
180
+ check_pred_type <- function (object , type , ... ) {
164
181
if (is.null(type )) {
165
182
type <-
166
183
switch (object $ spec $ mode ,
167
184
regression = " numeric" ,
168
185
classification = " class" ,
169
- rlang :: abort(" `type` should be 'regression' or 'classification'." ))
186
+ " censored regression" = " time" ,
187
+ rlang :: abort(" `type` should be 'regression', 'censored regression', or 'classification'." ))
170
188
}
171
189
if (! (type %in% pred_types ))
172
190
rlang :: abort(
@@ -181,6 +199,10 @@ check_pred_type <- function(object, type) {
181
199
rlang :: abort(" For class predictions, the object should be a classification model." )
182
200
if (type == " prob" & object $ spec $ mode != " classification" )
183
201
rlang :: abort(" For probability predictions, the object should be a classification model." )
202
+ if (type %in% surv_types & object $ spec $ mode != " censored regression" )
203
+ rlang :: abort(" For event time predictions, the object should be a censored regression." )
204
+
205
+ # TODO check for ... options when not the correct type
184
206
type
185
207
}
186
208
@@ -216,6 +238,61 @@ format_classprobs <- function(x) {
216
238
x
217
239
}
218
240
241
# Standardize the result of a `type = "time"` prediction as a tibble.
#
# `x` is the raw prediction result. If it is a data frame, or has more than
# one column, it is converted to a tibble and, unless some column name already
# starts with ".time", every column name is prefixed with ".time_". Any other
# input is stripped of its names and wrapped in a one-column tibble called
# `.pred_time`.
format_time <- function(x) {
  # Scalar `||`, not elementwise `|`: both operands are single logicals.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_time = unname(x))
  }

  x
}
253
+
254
# Standardize the result of a `type = "survival"` prediction as a tibble.
#
# `x` is the raw prediction result. If it is a data frame, or has more than
# one column, it is converted to a tibble and, unless some column name already
# starts with ".time", every column name is prefixed with ".time_". Any other
# input is stripped of its names and wrapped in a one-column tibble called
# `.pred_survival`.
format_survival <- function(x) {
  # Scalar `||`, not elementwise `|`: both operands are single logicals.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    # NOTE(review): the prefix is ".time_", the same as in format_time() --
    # confirm that this is intended for survival output.
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_survival = unname(x))
  }

  x
}
266
+
267
# Standardize the result of a `type = "linear_pred"` prediction as a tibble.
#
# `x` is the raw prediction result. Objects inheriting from "tbl_spark" are
# returned untouched. Otherwise, if `x` is a data frame or has more than one
# column, it is converted to a tibble and, unless some column name already
# starts with ".time", every column name is prefixed with ".time_". Any other
# input is stripped of its names and wrapped in a one-column tibble called
# `.pred_linear_pred`.
format_linear_pred <- function(x) {
  if (inherits(x, "tbl_spark")) {
    return(x)
  }

  # Scalar `||`, not elementwise `|`: both operands are single logicals.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    # NOTE(review): the prefix is ".time_", the same as in format_time() --
    # confirm that this is intended for linear predictor output.
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_linear_pred = unname(x))
  }

  x
}
282
+
283
# Standardize the result of a `type = "hazard"` prediction as a tibble.
#
# `x` is the raw prediction result. If it is a data frame, or has more than
# one column, it is converted to a tibble and, unless some column name already
# starts with ".time", every column name is prefixed with ".time_". Any other
# input is stripped of its names and wrapped in a one-column tibble called
# `.pred_hazard`.
format_hazard <- function(x) {
  # Scalar `||`, not elementwise `|`: both operands are single logicals.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    # NOTE(review): the prefix is ".time_", the same as in format_time() --
    # confirm that this is intended for hazard output.
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_hazard = unname(x))
  }

  x
}
295
+
219
296
make_pred_call <- function (x ) {
220
297
if (" pkg" %in% names(x $ func ))
221
298
cl <-
@@ -226,6 +303,54 @@ make_pred_call <- function(x) {
226
303
cl
227
304
}
228
305
306
# Validate the arguments passed through `...` for a given prediction type.
#
# `type` is a single prediction type string; `...` are the extra arguments
# given to `predict()`. Note that `list(...)` evaluates those arguments.
#
# Aborts when:
#   * an argument is named `newdata` (a likely typo for `new_data`),
#   * a named argument is not one of `level`, `std_error`, `quantile`, `.time`,
#   * `.time` is given but `type` is neither "survival" nor "hazard",
#   * `type` is "survival" or "hazard" but `.time` is not given.
#
# Returns `TRUE` invisibly when every check passes.
check_pred_type_dots <- function(type, ...) {
  the_dots <- list(...)
  nms <- names(the_dots)

  # ----------------------------------------------------------------------------

  if (any(nms == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  # ----------------------------------------------------------------------------

  other_args <- c("level", "std_error", "quantile", ".time")
  is_pred_arg <- nms %in% other_args
  if (any(!is_pred_arg)) {
    bad_args <- nms[!is_pred_arg]
    bad_args <- paste0("`", bad_args, "`", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The ellipses are not used to pass args to the model function's ",
        "predict function. These arguments cannot be used: {bad_args}"
      )
    )
  }

  # ----------------------------------------------------------------------------

  time_types <- c("survival", "hazard")
  has_time <- any(nms == ".time")
  # `type` is a single string, so these are scalar conditions: use `&&`.
  needs_time <- type %in% time_types

  # places where .time should not be given
  if (has_time && !needs_time) {
    rlang::abort(
      paste(
        ".time should only be passed to `predict()` when 'type' is one of:",
        paste0("'", time_types, "'", collapse = ", ")
      )
    )
  }
  # when .time should be passed
  if (!has_time && needs_time) {
    rlang::abort(
      paste(
        "When using 'type' values of 'survival' or 'hazard' are given,",
        "a numeric vector '.time' should also be given."
      )
    )
  }
  invisible(TRUE)
}
352
+
353
+
229
354
# ' Prepare data based on parsnip encoding information
230
355
# ' @param object A parsnip model object
231
356
# ' @param new_data A data frame
0 commit comments