Skip to content

Commit 2bbb111

Browse files
authored
Merge pull request #922 from tidymodels/formatting-functions
`format_<prediction type>()` functions only format when necessary
2 parents ba1adf7 + 838f709 commit 2bbb111

File tree

1 file changed

+23
-44
lines changed

1 file changed

+23
-44
lines changed

R/predict.R

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -222,28 +222,19 @@ check_pred_type <- function(object, type, ...) {
222222
#' @export
223223

224224
format_num <- function(x) {
225-
if (inherits(x, "tbl_spark"))
225+
if (inherits(x, "tbl_spark")) {
226226
return(x)
227-
228-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
229-
x <- as_tibble(x, .name_repair = "minimal")
230-
if (!any(grepl("^\\.pred", names(x)))) {
231-
names(x) <- paste0(".pred_", names(x))
232-
}
233-
} else {
234-
x <- tibble(.pred = unname(x))
235227
}
236-
237-
x
228+
ensure_parsnip_format(x, ".pred", overwrite = FALSE)
238229
}
239230

240231
#' @rdname format-internals
241232
#' @export
242233
format_class <- function(x) {
243-
if (inherits(x, "tbl_spark"))
234+
if (inherits(x, "tbl_spark")) {
244235
return(x)
245-
246-
tibble(.pred_class = unname(x))
236+
}
237+
ensure_parsnip_format(x, ".pred_class")
247238
}
248239

249240
#' @rdname format-internals
@@ -260,57 +251,45 @@ format_classprobs <- function(x) {
260251
#' @rdname format-internals
261252
#' @export
262253
format_time <- function(x) {
263-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
264-
x <- as_tibble(x, .name_repair = "minimal")
265-
if (!any(grepl("^\\.pred_time", names(x)))) {
266-
names(x) <- paste0(".pred_time_", names(x))
267-
}
268-
} else {
269-
x <- tibble(.pred_time = unname(x))
270-
}
271-
272-
x
254+
ensure_parsnip_format(x, ".pred_time", overwrite = FALSE)
273255
}
274256

275257
#' @rdname format-internals
276258
#' @export
277259
format_survival <- function(x) {
278-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
279-
x <- as_tibble(x, .name_repair = "minimal")
280-
names(x) <- ".pred"
281-
} else {
282-
x <- tibble(.pred = unname(x))
283-
}
284-
285-
x
260+
ensure_parsnip_format(x, ".pred")
286261
}
287262

288263
#' @rdname format-internals
289264
#' @export
290265
format_linear_pred <- function(x) {
291-
if (inherits(x, "tbl_spark"))
266+
if (inherits(x, "tbl_spark")){
292267
return(x)
293-
294-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
295-
x <- as_tibble(x, .name_repair = "minimal")
296-
names(x) <- ".pred_linear_pred"
297-
} else {
298-
x <- tibble(.pred_linear_pred = unname(x))
299268
}
300-
301-
x
269+
ensure_parsnip_format(x, ".pred_linear_pred")
302270
}
303271

304272
#' @rdname format-internals
305273
#' @export
306274
format_hazard <- function(x) {
275+
ensure_parsnip_format(x, ".pred")
276+
}
277+
278+
ensure_parsnip_format <- function(x, col_name, overwrite = TRUE) {
307279
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
308280
x <- as_tibble(x, .name_repair = "minimal")
309-
names(x) <- ".pred"
281+
if (!any(grepl(paste0("^\\", col_name), names(x)))) {
282+
if (overwrite) {
283+
names(x) <- col_name
284+
} else {
285+
names(x) <- paste(col_name, names(x), sep = "_")
286+
}
287+
}
310288
} else {
311-
x <- tibble(.pred_hazard = unname(x))
289+
x <- tibble(unname(x))
290+
names(x) <- col_name
291+
x
312292
}
313-
314293
x
315294
}
316295

0 commit comments

Comments
 (0)