Skip to content

Commit 3ade142

Browse files
committed
add predict_hazard()
1 parent 93926eb commit 3ade142

File tree

5 files changed

+67
-3
lines changed

5 files changed

+67
-3
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ S3method(predict_classprob,"_lognet")
3131
S3method(predict_classprob,"_multnet")
3232
S3method(predict_classprob,model_fit)
3333
S3method(predict_confint,model_fit)
34+
S3method(predict_hazard,model_fit)
3435
S3method(predict_linear_pred,model_fit)
3536
S3method(predict_numeric,"_elnet")
3637
S3method(predict_numeric,model_fit)
@@ -153,6 +154,7 @@ export(predict.model_fit)
153154
export(predict_class.model_fit)
154155
export(predict_classprob.model_fit)
155156
export(predict_confint.model_fit)
157+
export(predict_hazard.model_fit)
156158
export(predict_linear_pred)
157159
export(predict_linear_pred.model_fit)
158160
export(predict_numeric)

R/aaa_models.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ parsnip$modes <- c("regression", "classification", "unknown")
3232

3333
pred_types <-
3434
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
35-
"time", "survival", "linear_pred")
35+
"time", "survival", "linear_pred", "hazard")
3636

3737
# ------------------------------------------------------------------------------
3838

R/predict.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
147147
time = predict_time(object = object, new_data = new_data, ...),
148148
survival = predict_survival(object = object, new_data = new_data, ...),
149149
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
150+
hazard = predict_hazard(object = object, new_data = new_data, ...),
150151
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
151152
rlang::abort(glue::glue("I don't know about type = '{type}'"))
152153
)
@@ -158,6 +159,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
158159
prob = format_classprobs(res),
159160
time = format_time(res),
160161
survival = format_survival(res),
162+
hazard = format_hazard(res),
161163
linear_pred = format_linear_pred(res),
162164
res
163165
)
@@ -265,6 +267,19 @@ format_linear_pred <- function(x) {
265267
x
266268
}
267269

270+
format_hazard <- function(x) {
271+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
272+
x <- as_tibble(x, .name_repair = "minimal")
273+
if (!any(grepl("^\\.time", names(x)))) {
274+
names(x) <- paste0(".time_", names(x))
275+
}
276+
} else {
277+
x <- tibble(.pred_hazard = unname(x))
278+
}
279+
280+
x
281+
}
282+
268283
make_pred_call <- function(x) {
269284
if ("pkg" %in% names(x$func))
270285
cl <-

R/predict_hazard.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @inheritParams predict.model_fit
4+
#' @method predict_hazard model_fit
5+
#' @export predict_hazard.model_fit
6+
#' @export
7+
predict_hazard.model_fit <-
8+
function(object, new_data, .time, ...) {
9+
10+
if (is.null(object$spec$method$pred$hazard))
11+
rlang::abort("No hazard prediction method defined for this engine.")
12+
13+
if (inherits(object$fit, "try-error")) {
14+
rlang::warn("Model fit failed; cannot make predictions.")
15+
return(NULL)
16+
}
17+
18+
new_data <- prepare_data(object, new_data)
19+
20+
# preprocess data
21+
if (!is.null(object$spec$method$pred$hazard$pre))
22+
new_data <- object$spec$method$pred$hazard$pre(new_data, object)
23+
24+
# Pass some extra arguments to be used in post-processor
25+
object$spec$method$pred$hazard$args$.time <- .time
26+
pred_call <- make_pred_call(object$spec$method$pred$hazard)
27+
28+
res <- eval_tidy(pred_call)
29+
30+
# post-process the predictions
31+
if(!is.null(object$spec$method$pred$hazard$post)) {
32+
res <- object$spec$method$pred$hazard$post(res, object)
33+
}
34+
35+
res
36+
}
37+
38+
# @export
39+
# @keywords internal
40+
# @rdname other_predict
41+
# @inheritParams predict.model_fit
42+
predict_hazard <- function (object, ...)
43+
UseMethod("predict_hazard")

man/other_predict.Rd

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)