147
147
# ' @export
148
148
predict.model_fit <- function (object , new_data , type = NULL , opts = list (), ... ) {
149
149
if (inherits(object $ fit , " try-error" )) {
150
- rlang :: warn (" Model fit failed; cannot make predictions." )
150
+ cli :: cli_warn (" Model fit failed; cannot make predictions." )
151
151
return (NULL )
152
152
}
153
153
@@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
156
156
157
157
type <- check_pred_type(object , type )
158
158
if (type != " raw" && length(opts ) > 0 ) {
159
- rlang :: warn( " ` opts` is only used with `type = 'raw'` and was ignored." )
159
+ cli :: cli_warn( " {.arg opts} is only used with `type = 'raw'` and was ignored." )
160
160
}
161
161
check_pred_type_dots(object , type , ... )
162
162
@@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
173
173
linear_pred = predict_linear_pred(object = object , new_data = new_data , ... ),
174
174
hazard = predict_hazard(object = object , new_data = new_data , ... ),
175
175
raw = predict_raw(object = object , new_data = new_data , opts = opts , ... ),
176
- rlang :: abort( glue :: glue( " I don't know about type = '{type}'" ) )
176
+ cli :: cli_abort( " Unknown prediction {.arg type} '{type}'. " )
177
177
)
178
178
if (! inherits(res , " tbl_spark" )) {
179
179
res <- switch (
@@ -191,45 +191,69 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
191
191
res
192
192
}
193
193
194
- check_pred_type <- function (object , type , ... ) {
194
+ check_pred_type <- function (object , type , ... , call = rlang :: caller_env() ) {
195
195
if (is.null(type )) {
196
196
type <-
197
- switch (object $ spec $ mode ,
198
- regression = " numeric" ,
199
- classification = " class" ,
200
- " censored regression" = " time" ,
201
- rlang :: abort(" `type` should be 'regression', 'censored regression', or 'classification'." ))
197
+ switch (
198
+ object $ spec $ mode ,
199
+ regression = " numeric" ,
200
+ classification = " class" ,
201
+ " censored regression" = " time" ,
202
+ cli :: cli_abort(
203
+ " {.arg type} should be 'regression', 'censored regression', or 'classification'." ,
204
+ call = call
205
+ )
206
+ )
202
207
}
203
208
if (! (type %in% pred_types ))
204
- rlang :: abort(
205
- glue :: glue(
206
- " `type` should be one of: " ,
207
- glue_collapse(pred_types , sep = " , " , last = " and " )
208
- )
209
+ cli :: cli_abort(
210
+ " {.arg type} should be one of:{.arg {pred_types}}" ,
211
+ call = call
209
212
)
210
213
211
214
switch (
212
215
type ,
213
216
" numeric" = if (object $ spec $ mode != " regression" ) {
214
- rlang :: abort(" For numeric predictions, the object should be a regression model." )
217
+ cli :: cli_abort(
218
+ " For numeric predictions, the object should be a regression model." ,
219
+ call = call
220
+ )
215
221
},
216
222
" class" = if (object $ spec $ mode != " classification" ) {
217
- rlang :: abort(" For class predictions, the object should be a classification model." )
223
+ cli :: cli_abort(
224
+ " For class predictions, the object should be a classification model." ,
225
+ call = call
226
+ )
218
227
},
219
228
" prob" = if (object $ spec $ mode != " classification" ) {
220
- rlang :: abort(" For probability predictions, the object should be a classification model." )
229
+ cli :: cli_abort(
230
+ " For probability predictions, the object should be a classification model." ,
231
+ call = call
232
+ )
221
233
},
222
234
" time" = if (object $ spec $ mode != " censored regression" ) {
223
- rlang :: abort(" For event time predictions, the object should be a censored regression." )
235
+ cli :: cli_abort(
236
+ " For event time predictions, the object should be a censored regression." ,
237
+ call = call
238
+ )
224
239
},
225
240
" survival" = if (object $ spec $ mode != " censored regression" ) {
226
- rlang :: abort(" For survival probability predictions, the object should be a censored regression." )
241
+ cli :: cli_abort(
242
+ " For survival probability predictions, the object should be a censored regression." ,
243
+ call = call
244
+ )
227
245
},
228
246
" hazard" = if (object $ spec $ mode != " censored regression" ) {
229
- rlang :: abort(" For hazard predictions, the object should be a censored regression." )
247
+ cli :: cli_abort(
248
+ " For hazard predictions, the object should be a censored regression." ,
249
+ call = call
250
+ )
230
251
},
231
252
" linear_pred" = if (object $ spec $ mode != " censored regression" ) {
232
- rlang :: abort(" For the linear predictor, the object should be a censored regression." )
253
+ cli :: cli_abort(
254
+ " For the linear predictor, the object should be a censored regression." ,
255
+ call = call
256
+ )
233
257
}
234
258
)
235
259
@@ -349,56 +373,57 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
349
373
350
374
other_args <- c(" interval" , " level" , " std_error" , " quantile" ,
351
375
" time" , " eval_time" , " increasing" )
376
+
377
+ eval_time_types <- c(" survival" , " hazard" )
378
+
352
379
is_pred_arg <- names(the_dots ) %in% other_args
353
380
if (any(! is_pred_arg )) {
354
381
bad_args <- names(the_dots )[! is_pred_arg ]
355
382
bad_args <- paste0(" `" , bad_args , " `" , collapse = " , " )
356
- rlang :: abort(
357
- glue :: glue(
358
- " The ellipses are not used to pass args to the model function's " ,
359
- " predict function. These arguments cannot be used: {bad_args}" ,
360
- )
383
+ cli :: cli_abort(
384
+ " The ellipses are not used to pass args to the model function's
385
+ predict function. These arguments cannot be used: {.val bad_args}" ,
386
+ call = call
361
387
)
362
388
}
363
389
364
390
# ----------------------------------------------------------------------------
365
391
# places where eval_time should not be given
366
392
if (any(nms == " eval_time" ) & ! type %in% c(" survival" , " hazard" )) {
367
- rlang :: abort(
368
- paste(
369
- " `eval_time` should only be passed to `predict()` when `type` is one of:" ,
370
- paste0(" '" , c(" survival" , " hazard" ), " '" , collapse = " , " )
371
- )
372
- )
393
+ cli :: cli_abort(
394
+ " {.arg eval_time} should only be passed to {.fn predict} when \\
395
+ {.arg type} is one of {.or {.val {eval_time_types}}}." ,
396
+ call = call
397
+ )
398
+
399
+
373
400
}
374
401
if (any(nms == " time" ) & ! type %in% c(" survival" , " hazard" )) {
375
- rlang :: abort(
376
- paste(
377
- " 'time' should only be passed to `predict()` when 'type' is one of:" ,
378
- paste0(" '" , c(" survival" , " hazard" ), " '" , collapse = " , " )
379
- )
402
+ cli :: cli_abort(
403
+ " {.arg time} should only be passed to {.fn predict} when {.arg type} is
404
+ one of {.or {.val {eval_time_types}}}." ,
405
+ call = call
380
406
)
381
407
}
382
408
# when eval_time should be passed
383
409
if (! any(nms %in% c(" eval_time" , " time" )) & type %in% c(" survival" , " hazard" )) {
384
- rlang :: abort(
385
- paste(
386
- " When using `type` values of 'survival' or 'hazard'," ,
387
- " a numeric vector `eval_time` should also be given."
388
- )
389
- )
410
+ cli :: cli_abort(
411
+ " When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric
412
+ vector {.arg eval_time} should also be given." ,
413
+ call = call
414
+ )
390
415
}
391
416
392
417
# `increasing` only applies to linear_pred for censored regression
393
418
if (any(nms == " increasing" ) &
394
419
! (type == " linear_pred" &
395
420
object $ spec $ mode == " censored regression" )) {
396
- rlang :: abort(
397
- paste(
398
- " The 'increasing' argument only applies to predictions of" ,
399
- " type 'linear_pred' for the mode censored regression."
400
- )
421
+ cli :: cli_abort(
422
+ " {.arg increasing} only applies to predictions of
423
+ type 'linear_pred' for the mode censored regression." ,
424
+ call = call
401
425
)
426
+
402
427
}
403
428
404
429
invisible (TRUE )
0 commit comments