@@ -222,37 +222,19 @@ check_pred_type <- function(object, type, ...) {
222
222
# ' @export
223
223
224
224
format_num <- function (x ) {
225
- if (inherits(x , " tbl_spark" ))
225
+ if (inherits(x , " tbl_spark" )) {
226
226
return (x )
227
-
228
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
229
- x <- as_tibble(x , .name_repair = " minimal" )
230
- if (! any(grepl(" ^\\ .pred" , names(x )))) {
231
- names(x ) <- paste0(" .pred_" , names(x ))
232
- }
233
- } else {
234
- x <- tibble(.pred = unname(x ))
235
227
}
236
-
237
- x
228
+ ensure_parsnip_format(x , " .pred" , overwrite = FALSE )
238
229
}
239
230
240
231
# ' @rdname format-internals
241
232
# ' @export
242
233
format_class <- function (x ) {
243
- if (inherits(x , " tbl_spark" ))
234
+ if (inherits(x , " tbl_spark" )) {
244
235
return (x )
245
-
246
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
247
- x <- as_tibble(x , .name_repair = " minimal" )
248
- if (! any(grepl(" ^\\ .pred_class" , names(x )))) {
249
- names(x ) <- " .pred_class"
250
- }
251
- } else {
252
- x <- tibble(.pred_class = unname(x ))
253
236
}
254
-
255
- x
237
+ ensure_parsnip_format(x , " .pred_class" )
256
238
}
257
239
258
240
# ' @rdname format-internals
@@ -269,64 +251,45 @@ format_classprobs <- function(x) {
269
251
# ' @rdname format-internals
270
252
# ' @export
271
253
format_time <- function (x ) {
272
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
273
- x <- as_tibble(x , .name_repair = " minimal" )
274
- if (! any(grepl(" ^\\ .pred_time" , names(x )))) {
275
- names(x ) <- paste0(" .pred_time_" , names(x ))
276
- }
277
- } else {
278
- x <- tibble(.pred_time = unname(x ))
279
- }
280
-
281
- x
254
+ ensure_parsnip_format(x , " .pred_time" , overwrite = FALSE )
282
255
}
283
256
284
257
# ' @rdname format-internals
285
258
# ' @export
286
259
format_survival <- function (x ) {
287
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
288
- x <- as_tibble(x , .name_repair = " minimal" )
289
- if (! any(grepl(" ^\\ .pred" , names(x )))) {
290
- names(x ) <- " .pred"
291
- }
292
- } else {
293
- x <- tibble(.pred = unname(x ))
294
- }
295
-
296
- x
260
+ ensure_parsnip_format(x , " .pred" )
297
261
}
298
262
299
263
# ' @rdname format-internals
300
264
# ' @export
301
265
format_linear_pred <- function (x ) {
302
- if (inherits(x , " tbl_spark" ))
266
+ if (inherits(x , " tbl_spark" )){
303
267
return (x )
304
-
305
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
306
- x <- as_tibble(x , .name_repair = " minimal" )
307
- if (! any(grepl(" ^\\ .pred_linear_pred" , names(x )))) {
308
- names(x ) <- " .pred_linear_pred"
309
- }
310
- } else {
311
- x <- tibble(.pred_linear_pred = unname(x ))
312
268
}
313
-
314
- x
269
+ ensure_parsnip_format(x , " .pred_linear_pred" )
315
270
}
316
271
317
272
# ' @rdname format-internals
318
273
# ' @export
319
274
format_hazard <- function (x ) {
275
+ ensure_parsnip_format(x , " .pred" )
276
+ }
277
+
278
+ ensure_parsnip_format <- function (x , col_name , overwrite = TRUE ) {
320
279
if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
321
280
x <- as_tibble(x , .name_repair = " minimal" )
322
- if (! any(grepl(" ^\\ .pred" , names(x )))) {
323
- names(x ) <- " .pred"
281
+ if (! any(grepl(paste0(" ^\\ " , col_name ), names(x )))) {
282
+ if (overwrite ) {
283
+ names(x ) <- col_name
284
+ } else {
285
+ names(x ) <- paste(col_name , names(x ), sep = " _" )
286
+ }
324
287
}
288
+ } else {
289
+ x <- tibble(unname(x ))
290
+ names(x ) <- col_name
291
+ x
325
292
}
326
- else {
327
- x <- tibble(.pred_hazard = unname(x ))
328
- }
329
-
330
293
x
331
294
}
332
295
0 commit comments