@@ -222,28 +222,19 @@ check_pred_type <- function(object, type, ...) {
222
222
# ' @export
223
223
224
224
format_num <- function (x ) {
225
- if (inherits(x , " tbl_spark" ))
225
+ if (inherits(x , " tbl_spark" )) {
226
226
return (x )
227
-
228
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
229
- x <- as_tibble(x , .name_repair = " minimal" )
230
- if (! any(grepl(" ^\\ .pred" , names(x )))) {
231
- names(x ) <- paste0(" .pred_" , names(x ))
232
- }
233
- } else {
234
- x <- tibble(.pred = unname(x ))
235
227
}
236
-
237
- x
228
+ ensure_parsnip_format(x , " .pred" , overwrite = FALSE )
238
229
}
239
230
240
231
# ' @rdname format-internals
241
232
# ' @export
242
233
format_class <- function (x ) {
243
- if (inherits(x , " tbl_spark" ))
234
+ if (inherits(x , " tbl_spark" )) {
244
235
return (x )
245
-
246
- tibble( .pred_class = unname( x ) )
236
+ }
237
+ ensure_parsnip_format( x , " .pred_class" )
247
238
}
248
239
249
240
# ' @rdname format-internals
@@ -260,57 +251,45 @@ format_classprobs <- function(x) {
260
251
# ' @rdname format-internals
261
252
# ' @export
262
253
format_time <- function (x ) {
263
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
264
- x <- as_tibble(x , .name_repair = " minimal" )
265
- if (! any(grepl(" ^\\ .pred_time" , names(x )))) {
266
- names(x ) <- paste0(" .pred_time_" , names(x ))
267
- }
268
- } else {
269
- x <- tibble(.pred_time = unname(x ))
270
- }
271
-
272
- x
254
+ ensure_parsnip_format(x , " .pred_time" , overwrite = FALSE )
273
255
}
274
256
275
257
# ' @rdname format-internals
276
258
# ' @export
277
259
format_survival <- function (x ) {
278
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
279
- x <- as_tibble(x , .name_repair = " minimal" )
280
- names(x ) <- " .pred"
281
- } else {
282
- x <- tibble(.pred = unname(x ))
283
- }
284
-
285
- x
260
+ ensure_parsnip_format(x , " .pred" )
286
261
}
287
262
288
263
# ' @rdname format-internals
289
264
# ' @export
290
265
format_linear_pred <- function (x ) {
291
- if (inherits(x , " tbl_spark" ))
266
+ if (inherits(x , " tbl_spark" )){
292
267
return (x )
293
-
294
- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
295
- x <- as_tibble(x , .name_repair = " minimal" )
296
- names(x ) <- " .pred_linear_pred"
297
- } else {
298
- x <- tibble(.pred_linear_pred = unname(x ))
299
268
}
300
-
301
- x
269
+ ensure_parsnip_format(x , " .pred_linear_pred" )
302
270
}
303
271
304
272
# ' @rdname format-internals
305
273
# ' @export
306
274
format_hazard <- function (x ) {
275
+ ensure_parsnip_format(x , " .pred" )
276
+ }
277
+
278
+ ensure_parsnip_format <- function (x , col_name , overwrite = TRUE ) {
307
279
if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
308
280
x <- as_tibble(x , .name_repair = " minimal" )
309
- names(x ) <- " .pred"
281
+ if (! any(grepl(paste0(" ^\\ " , col_name ), names(x )))) {
282
+ if (overwrite ) {
283
+ names(x ) <- col_name
284
+ } else {
285
+ names(x ) <- paste(col_name , names(x ), sep = " _" )
286
+ }
287
+ }
310
288
} else {
311
- x <- tibble(.pred_hazard = unname(x ))
289
+ x <- tibble(unname(x ))
290
+ names(x ) <- col_name
291
+ x
312
292
}
313
-
314
293
x
315
294
}
316
295
0 commit comments