Skip to content

Commit c0b78d1

Browse files
committed
transfer over PR tidymodels#359 due to deletion of branch
1 parent 97b576e commit c0b78d1

File tree

8 files changed

+231
-14
lines changed

8 files changed

+231
-14
lines changed

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ S3method(predict_classprob,"_lognet")
3131
S3method(predict_classprob,"_multnet")
3232
S3method(predict_classprob,model_fit)
3333
S3method(predict_confint,model_fit)
34+
S3method(predict_linear_pred,model_fit)
3435
S3method(predict_numeric,"_elnet")
3536
S3method(predict_numeric,model_fit)
3637
S3method(predict_quantile,model_fit)
3738
S3method(predict_raw,"_elnet")
3839
S3method(predict_raw,"_lognet")
3940
S3method(predict_raw,"_multnet")
4041
S3method(predict_raw,model_fit)
42+
S3method(predict_survival,model_fit)
43+
S3method(predict_time,model_fit)
4144
S3method(print,boost_tree)
4245
S3method(print,control_parsnip)
4346
S3method(print,decision_tree)
@@ -150,11 +153,16 @@ export(predict.model_fit)
150153
export(predict_class.model_fit)
151154
export(predict_classprob.model_fit)
152155
export(predict_confint.model_fit)
156+
export(predict_linear_pred)
157+
export(predict_linear_pred.model_fit)
153158
export(predict_numeric)
154159
export(predict_numeric.model_fit)
155160
export(predict_quantile.model_fit)
156161
export(predict_raw)
157162
export(predict_raw.model_fit)
163+
export(predict_survival.model_fit)
164+
export(predict_time)
165+
export(predict_time.model_fit)
158166
export(prepare_data)
159167
export(rand_forest)
160168
export(repair_call)

R/aaa_models.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
3131
# ------------------------------------------------------------------------------
3232

3333
pred_types <-
34-
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
34+
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
35+
"time", "survival", "linear_pred")
3536

3637
# ------------------------------------------------------------------------------
3738

R/misc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ check_empty_ellipse <- function (...) {
2323
terms
2424
}
2525

26-
all_modes <- c("classification", "regression")
26+
all_modes <- c("classification", "regression", "censored regression")
2727

2828

2929
deparserizer <- function(x, limit = options()$width - 10) {

R/predict.R

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
120120
check_installs(object$spec)
121121
load_libs(object$spec, quiet = TRUE)
122122

123-
other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
123+
other_args <- c("level", "std_error", "quantile", ".time")
124124
is_pred_arg <- names(the_dots) %in% other_args
125125
if (any(!is_pred_arg)) {
126126
bad_args <- names(the_dots)[!is_pred_arg]
@@ -138,21 +138,27 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
138138
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
139139
res <- switch(
140140
type,
141-
numeric = predict_numeric(object = object, new_data = new_data, ...),
142-
class = predict_class(object = object, new_data = new_data, ...),
143-
prob = predict_classprob(object = object, new_data = new_data, ...),
144-
conf_int = predict_confint(object = object, new_data = new_data, ...),
145-
pred_int = predict_predint(object = object, new_data = new_data, ...),
146-
quantile = predict_quantile(object = object, new_data = new_data, ...),
147-
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
141+
numeric = predict_numeric(object = object, new_data = new_data, ...),
142+
class = predict_class(object = object, new_data = new_data, ...),
143+
prob = predict_classprob(object = object, new_data = new_data, ...),
144+
conf_int = predict_confint(object = object, new_data = new_data, ...),
145+
pred_int = predict_predint(object = object, new_data = new_data, ...),
146+
quantile = predict_quantile(object = object, new_data = new_data, ...),
147+
time = predict_time(object = object, new_data = new_data, ...),
148+
survival = predict_survival(object = object, new_data = new_data, ...),
149+
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
150+
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
148151
rlang::abort(glue::glue("I don't know about type = '{type}'"))
149152
)
150153
if (!inherits(res, "tbl_spark")) {
151154
res <- switch(
152155
type,
153-
numeric = format_num(res),
154-
class = format_class(res),
155-
prob = format_classprobs(res),
156+
numeric = format_num(res),
157+
class = format_class(res),
158+
prob = format_classprobs(res),
159+
time = format_time(res),
160+
survival = format_survival(res),
161+
linear_pred = format_linear_pred(res),
156162
res
157163
)
158164
}
@@ -166,6 +172,7 @@ check_pred_type <- function(object, type) {
166172
switch(object$spec$mode,
167173
regression = "numeric",
168174
classification = "class",
175+
"censored regression" = "time",
169176
rlang::abort("`type` should be 'regression' or 'classification'."))
170177
}
171178
if (!(type %in% pred_types))
@@ -216,6 +223,48 @@ format_classprobs <- function(x) {
216223
x
217224
}
218225

226+
format_time <- function(x) {
227+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
228+
x <- as_tibble(x, .name_repair = "minimal")
229+
if (!any(grepl("^\\.time", names(x)))) {
230+
names(x) <- paste0(".time_", names(x))
231+
}
232+
} else {
233+
x <- tibble(.pred_time = unname(x))
234+
}
235+
236+
x
237+
}
238+
239+
format_survival <- function(x) {
240+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
241+
x <- as_tibble(x, .name_repair = "minimal")
242+
if (!any(grepl("^\\.time", names(x)))) {
243+
names(x) <- paste0(".time_", names(x))
244+
}
245+
} else {
246+
x <- tibble(.pred_survival = unname(x))
247+
}
248+
249+
x
250+
}
251+
252+
format_linear_pred <- function(x) {
253+
if (inherits(x, "tbl_spark"))
254+
return(x)
255+
256+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
257+
x <- as_tibble(x, .name_repair = "minimal")
258+
if (!any(grepl("^\\.time", names(x)))) {
259+
names(x) <- paste0(".time_", names(x))
260+
}
261+
} else {
262+
x <- tibble(.pred_linear_pred = unname(x))
263+
}
264+
265+
x
266+
}
267+
219268
make_pred_call <- function(x) {
220269
if ("pkg" %in% names(x$func))
221270
cl <-

R/predict_linear_pred.R

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @inheritParams predict.model_fit
4+
#' @method predict_linear_pred model_fit
5+
#' @export predict_linear_pred.model_fit
6+
#' @export
7+
predict_linear_pred.model_fit <- function(object, new_data, ...) {
8+
9+
if (!any(names(object$spec$method$pred) == "linear_pred"))
10+
rlang::abort("No prediction module defined for this model.")
11+
12+
if (inherits(object$fit, "try-error")) {
13+
rlang::warn("Model fit failed; cannot make predictions.")
14+
return(NULL)
15+
}
16+
17+
new_data <- prepare_data(object, new_data)
18+
19+
# preprocess data
20+
if (!is.null(object$spec$method$pred$linear_pred$pre))
21+
new_data <- object$spec$method$pred$linear_pred$pre(new_data, object)
22+
23+
# create prediction call
24+
pred_call <- make_pred_call(object$spec$method$pred$linear_pred)
25+
26+
res <- eval_tidy(pred_call)
27+
# post-process the predictions
28+
29+
if (!is.null(object$spec$method$pred$linear_pred$post)) {
30+
res <- object$spec$method$pred$linear_pred$post(res, object)
31+
}
32+
33+
if (is.vector(res)) {
34+
res <- unname(res)
35+
} else {
36+
if (!inherits(res, "tbl_spark"))
37+
res <- as.data.frame(res)
38+
}
39+
res
40+
}
41+
42+
43+
#' @export
44+
#' @keywords internal
45+
#' @rdname other_predict
46+
#' @inheritParams predict_linear_pred.model_fit
47+
predict_linear_pred <- function(object, ...)
48+
UseMethod("predict_linear_pred")

R/predict_survival.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @inheritParams predict.model_fit
4+
#' @method predict_survival model_fit
5+
#' @export predict_survival.model_fit
6+
#' @export
7+
predict_survival.model_fit <-
8+
function(object, new_data, .time, ...) {
9+
10+
if (is.null(object$spec$method$pred$survival))
11+
rlang::abort("No survival prediction method defined for this engine.")
12+
13+
if (inherits(object$fit, "try-error")) {
14+
rlang::warn("Model fit failed; cannot make predictions.")
15+
return(NULL)
16+
}
17+
18+
new_data <- prepare_data(object, new_data)
19+
20+
# preprocess data
21+
if (!is.null(object$spec$method$pred$survival$pre))
22+
new_data <- object$spec$method$pred$survival$pre(new_data, object)
23+
24+
# Pass some extra arguments to be used in post-processor
25+
object$spec$method$pred$survival$args$.time <- .time
26+
pred_call <- make_pred_call(object$spec$method$pred$survival)
27+
28+
res <- eval_tidy(pred_call)
29+
30+
# post-process the predictions
31+
if(!is.null(object$spec$method$pred$survival$post)) {
32+
res <- object$spec$method$pred$survival$post(res, object)
33+
}
34+
35+
res
36+
}
37+
38+
# @export
39+
# @keywords internal
40+
# @rdname other_predict
41+
# @inheritParams predict.model_fit
42+
predict_survival <- function (object, ...)
43+
UseMethod("predict_survival")

R/predict_time.R

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @inheritParams predict.model_fit
4+
#' @method predict_time model_fit
5+
#' @export predict_time.model_fit
6+
#' @export
7+
predict_time.model_fit <- function(object, new_data, ...) {
8+
if (object$spec$mode != "censored regression")
9+
rlang::abort(glue::glue("`predict_time()` is for predicting time outcomes. ",
10+
"Use `predict_class()` or `predict_classprob()` for ",
11+
"classification models."))
12+
13+
if (!any(names(object$spec$method$pred) == "time"))
14+
rlang::abort("No prediction module defined for this model.")
15+
16+
if (inherits(object$fit, "try-error")) {
17+
rlang::warn("Model fit failed; cannot make predictions.")
18+
return(NULL)
19+
}
20+
21+
new_data <- prepare_data(object, new_data)
22+
23+
# preprocess data
24+
if (!is.null(object$spec$method$pred$time$pre))
25+
new_data <- object$spec$method$pred$time$pre(new_data, object)
26+
27+
# create prediction call
28+
pred_call <- make_pred_call(object$spec$method$pred$time)
29+
30+
res <- eval_tidy(pred_call)
31+
# post-process the predictions
32+
33+
if (!is.null(object$spec$method$pred$time$post)) {
34+
res <- object$spec$method$pred$time$post(res, object)
35+
}
36+
37+
if (is.vector(res)) {
38+
res <- unname(res)
39+
} else {
40+
if (!inherits(res, "tbl_spark"))
41+
res <- as.data.frame(res)
42+
}
43+
res
44+
}
45+
46+
47+
#' @export
48+
#' @keywords internal
49+
#' @rdname other_predict
50+
#' @inheritParams predict_time.model_fit
51+
predict_time <- function(object, ...)
52+
UseMethod("predict_time")

man/other_predict.Rd

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)