@@ -179,8 +179,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
179
179
res
180
180
}
181
181
182
- surv_types <- c(" time" , " survival" , " hazard" )
183
-
184
182
check_pred_type <- function (object , type , ... ) {
185
183
if (is.null(type )) {
186
184
type <-
@@ -197,14 +195,31 @@ check_pred_type <- function(object, type, ...) {
197
195
glue_collapse(pred_types , sep = " , " , last = " and " )
198
196
)
199
197
)
200
- if (type == " numeric" & object $ spec $ mode != " regression" )
201
- rlang :: abort(" For numeric predictions, the object should be a regression model." )
202
- if (type == " class" & object $ spec $ mode != " classification" )
203
- rlang :: abort(" For class predictions, the object should be a classification model." )
204
- if (type == " prob" & object $ spec $ mode != " classification" )
205
- rlang :: abort(" For probability predictions, the object should be a classification model." )
206
- if (type %in% surv_types & object $ spec $ mode != " censored regression" )
207
- rlang :: abort(" For event time predictions, the object should be a censored regression." )
198
+
199
+ switch (
200
+ type ,
201
+ " numeric" = if (object $ spec $ mode != " regression" ) {
202
+ rlang :: abort(" For numeric predictions, the object should be a regression model." )
203
+ },
204
+ " class" = if (object $ spec $ mode != " classification" ) {
205
+ rlang :: abort(" For class predictions, the object should be a classification model." )
206
+ },
207
+ " prob" = if (object $ spec $ mode != " classification" ) {
208
+ rlang :: abort(" For probability predictions, the object should be a classification model." )
209
+ },
210
+ " time" = if (object $ spec $ mode != " censored regression" ) {
211
+ rlang :: abort(" For event time predictions, the object should be a censored regression." )
212
+ },
213
+ " survival" = if (object $ spec $ mode != " censored regression" ) {
214
+ rlang :: abort(" For survival probability predictions, the object should be a censored regression." )
215
+ },
216
+ " hazard" = if (object $ spec $ mode != " censored regression" ) {
217
+ rlang :: abort(" For hazard predictions, the object should be a censored regression." )
218
+ },
219
+ " linear_pred" = if (object $ spec $ mode != " censored regression" ) {
220
+ rlang :: abort(" For the linear predictor, the object should be a censored regression." )
221
+ }
222
+ )
208
223
209
224
# TODO check for ... options when not the correct type
210
225
type
0 commit comments