Skip to content

Commit 838f709

Browse files
committed
add format helper
1 parent f03548b commit 838f709

File tree

1 file changed

+22
-59
lines changed

1 file changed

+22
-59
lines changed

R/predict.R

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -222,37 +222,19 @@ check_pred_type <- function(object, type, ...) {
222222
#' @export
223223

224224
format_num <- function(x) {
225-
if (inherits(x, "tbl_spark"))
225+
if (inherits(x, "tbl_spark")) {
226226
return(x)
227-
228-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
229-
x <- as_tibble(x, .name_repair = "minimal")
230-
if (!any(grepl("^\\.pred", names(x)))) {
231-
names(x) <- paste0(".pred_", names(x))
232-
}
233-
} else {
234-
x <- tibble(.pred = unname(x))
235227
}
236-
237-
x
228+
ensure_parsnip_format(x, ".pred", overwrite = FALSE)
238229
}
239230

240231
#' @rdname format-internals
241232
#' @export
242233
format_class <- function(x) {
243-
if (inherits(x, "tbl_spark"))
234+
if (inherits(x, "tbl_spark")) {
244235
return(x)
245-
246-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
247-
x <- as_tibble(x, .name_repair = "minimal")
248-
if (!any(grepl("^\\.pred_class", names(x)))) {
249-
names(x) <- ".pred_class"
250-
}
251-
} else {
252-
x <- tibble(.pred_class = unname(x))
253236
}
254-
255-
x
237+
ensure_parsnip_format(x, ".pred_class")
256238
}
257239

258240
#' @rdname format-internals
@@ -269,64 +251,45 @@ format_classprobs <- function(x) {
269251
#' @rdname format-internals
270252
#' @export
271253
format_time <- function(x) {
272-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
273-
x <- as_tibble(x, .name_repair = "minimal")
274-
if (!any(grepl("^\\.pred_time", names(x)))) {
275-
names(x) <- paste0(".pred_time_", names(x))
276-
}
277-
} else {
278-
x <- tibble(.pred_time = unname(x))
279-
}
280-
281-
x
254+
ensure_parsnip_format(x, ".pred_time", overwrite = FALSE)
282255
}
283256

284257
#' @rdname format-internals
285258
#' @export
286259
format_survival <- function(x) {
287-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
288-
x <- as_tibble(x, .name_repair = "minimal")
289-
if (!any(grepl("^\\.pred", names(x)))) {
290-
names(x) <- ".pred"
291-
}
292-
} else {
293-
x <- tibble(.pred = unname(x))
294-
}
295-
296-
x
260+
ensure_parsnip_format(x, ".pred")
297261
}
298262

299263
#' @rdname format-internals
300264
#' @export
301265
format_linear_pred <- function(x) {
302-
if (inherits(x, "tbl_spark"))
266+
if (inherits(x, "tbl_spark")){
303267
return(x)
304-
305-
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
306-
x <- as_tibble(x, .name_repair = "minimal")
307-
if (!any(grepl("^\\.pred_linear_pred", names(x)))) {
308-
names(x) <- ".pred_linear_pred"
309-
}
310-
} else {
311-
x <- tibble(.pred_linear_pred = unname(x))
312268
}
313-
314-
x
269+
ensure_parsnip_format(x, ".pred_linear_pred")
315270
}
316271

317272
#' @rdname format-internals
318273
#' @export
319274
format_hazard <- function(x) {
275+
ensure_parsnip_format(x, ".pred")
276+
}
277+
278+
ensure_parsnip_format <- function(x, col_name, overwrite = TRUE) {
320279
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
321280
x <- as_tibble(x, .name_repair = "minimal")
322-
if (!any(grepl("^\\.pred", names(x)))) {
323-
names(x) <- ".pred"
281+
if (!any(grepl(paste0("^\\", col_name), names(x)))) {
282+
if (overwrite) {
283+
names(x) <- col_name
284+
} else {
285+
names(x) <- paste(col_name, names(x), sep = "_")
286+
}
324287
}
288+
} else {
289+
x <- tibble(unname(x))
290+
names(x) <- col_name
291+
x
325292
}
326-
else {
327-
x <- tibble(.pred_hazard = unname(x))
328-
}
329-
330293
x
331294
}
332295

0 commit comments

Comments
 (0)