Skip to content

Commit fcb1071

Browse files
committed
tryCatch for class probabilities in augment()
1 parent bc125e9 commit fcb1071

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

R/augment.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,17 @@ augment.model_fit <- function(x, new_data, ...) {
6868
}
6969
}
7070
} else if (x$spec$mode == "classification") {
71-
new_data <-
72-
new_data %>%
73-
dplyr::bind_cols(
74-
predict(x, new_data = new_data, type = "class"),
71+
new_data <- dplyr::bind_cols(
72+
new_data,
73+
predict(x, new_data = new_data, type = "class")
74+
)
75+
tryCatch(
76+
new_data <- dplyr::bind_cols(
77+
new_data,
7578
predict(x, new_data = new_data, type = "prob")
76-
)
79+
),
80+
error = function(cnd) cnd
81+
)
7782
} else {
7883
rlang::abort(paste("Unknown mode:", x$spec$mode))
7984
}

0 commit comments

Comments
 (0)