Skip to content

WIP: adding new prediction types for Survnip #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_linear_pred,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(predict_survival,model_fit)
S3method(predict_time,model_fit)
S3method(print,boost_tree)
S3method(print,control_parsnip)
S3method(print,decision_tree)
Expand Down Expand Up @@ -145,11 +148,16 @@ export(predict.model_fit)
export(predict_class.model_fit)
export(predict_classprob.model_fit)
export(predict_confint.model_fit)
export(predict_linear_pred)
export(predict_linear_pred.model_fit)
export(predict_numeric)
export(predict_numeric.model_fit)
export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(predict_survival.model_fit)
export(predict_time)
export(predict_time.model_fit)
export(rand_forest)
export(repair_call)
export(req_pkgs)
Expand Down
3 changes: 2 additions & 1 deletion R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
# ------------------------------------------------------------------------------

pred_types <-
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
"time", "survival", "linear_pred")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the infrastructure for a "hazard" type?


# ------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ check_empty_ellipse <- function (...) {
terms
}

all_modes <- c("classification", "regression")
all_modes <- c("classification", "regression", "censored regression")


deparserizer <- function(x, limit = options()$width - 10) {
Expand Down
77 changes: 66 additions & 11 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
check_installs(object$spec)
load_libs(object$spec, quiet = TRUE)

other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
other_args <- c("level", "std_error", "quantile", ".time")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
Expand All @@ -138,21 +138,27 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
res <- switch(
type,
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
time = predict_time(object = object, new_data = new_data, ...),
survival = predict_survival(object = object, new_data = new_data, ...),
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
type,
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
time = format_time(res),
survival = format_survival(res),
linear_pred = format_linear_pred(res),
res
)
}
Expand All @@ -166,6 +172,7 @@ check_pred_type <- function(object, type) {
switch(object$spec$mode,
regression = "numeric",
classification = "class",
"censored regression" = "time",
rlang::abort("`type` should be 'regression' or 'classification'."))
}
if (!(type %in% pred_types))
Expand Down Expand Up @@ -216,6 +223,54 @@ format_classprobs <- function(x) {
x
}

# Format time predictions as a tibble.
#
# - A `tbl_spark` input is returned unchanged.
# - A data frame, or any object with more than one column, is converted with
#   `as_tibble(.name_repair = "minimal")`; when no column name starts with
#   ".time", every column name is prefixed with ".time_".
# - Anything else is stripped of names and wrapped in a one-column tibble
#   called `.pred_time`.
format_time <- function(x) {
if (inherits(x, "tbl_spark"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove the spark stuff from the survival bits

return(x)

# `isTRUE()` guards against `ncol(x)` being NULL (e.g. for plain vectors).
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
x <- as_tibble(x, .name_repair = "minimal")
# Only add the prefix when none of the columns already carries it.
if (!any(grepl("^\\.time", names(x)))) {
names(x) <- paste0(".time_", names(x))
}
} else {
# Single vector-like result: drop names and use the standard column name.
x <- tibble(.pred_time = unname(x))
}

x
}

# Format survival predictions as a tibble.
#
# - A `tbl_spark` input is returned unchanged.
# - A data frame, or any object with more than one column, is converted with
#   `as_tibble(.name_repair = "minimal")`; when no column name starts with
#   ".time", every column name is prefixed with ".time_".
# - Anything else is stripped of names and wrapped in a one-column tibble
#   called `.pred_survival`.
format_survival <- function(x) {
  if (inherits(x, "tbl_spark")) {
    return(x)
  }

  # Both operands are length-one logicals, so the scalar, short-circuiting
  # `||` is the right operator for this `if` condition. `isTRUE()` guards
  # against `ncol(x)` being NULL (e.g. for plain vectors).
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    # Only add the prefix when none of the columns already carries it.
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    # Single vector-like result: drop names and use the standard column name.
    x <- tibble(.pred_survival = unname(x))
  }

  x
}

# Format linear-predictor predictions as a tibble.
#
# - A `tbl_spark` input is returned unchanged.
# - A data frame, or any object with more than one column, is converted with
#   `as_tibble(.name_repair = "minimal")`; when no column name starts with
#   ".time", every column name is prefixed with ".time_".
# - Anything else is stripped of names and wrapped in a one-column tibble
#   called `.pred_linear_pred`.
format_linear_pred <- function(x) {
  if (inherits(x, "tbl_spark")) {
    return(x)
  }

  # Both operands are length-one logicals, so the scalar, short-circuiting
  # `||` is the right operator for this `if` condition. `isTRUE()` guards
  # against `ncol(x)` being NULL (e.g. for plain vectors).
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    # NOTE(review): the ".time" pattern/prefix matches `format_time()` and
    # looks copied from it; confirm this is the intended naming for
    # multi-column linear predictors before relying on it.
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    # Single vector-like result: drop names and use the standard column name.
    x <- tibble(.pred_linear_pred = unname(x))
  }

  x
}

make_pred_call <- function(x) {
if ("pkg" %in% names(x$func))
cl <-
Expand Down
52 changes: 52 additions & 0 deletions R/predict_linear_pred.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict.model_fit
#' @method predict_linear_pred model_fit
#' @export predict_linear_pred.model_fit
#' @export
predict_linear_pred.model_fit <- function(object, new_data, ...) {
# Produce linear-predictor predictions for a fitted model using the
# `linear_pred` prediction module stored in `object$spec$method$pred`.
# Returns an unnamed vector, a data frame, or a `tbl_spark` (left as is);
# returns NULL with a warning when the stored fit is a "try-error".

# Only the "censored regression" mode is accepted by this check.
# NOTE(review): the error text below talks about classification models even
# though the condition tests for "censored regression" -- confirm the wording.
if (object$spec$mode != "censored regression")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd remove this check. We can do these types of predictions for all types of models.

rlang::abort(glue::glue("`predict_linear_pred()` is for predicting linear predictors. ",
"Use `predict_class()` or `predict_classprob()` for ",
"classification models."))

# The model specification must define a module named exactly "linear_pred".
if (!any(names(object$spec$method$pred) == "linear_pred"))
rlang::abort("No prediction module defined for this model.")

# A failed fit cannot predict: warn and return NULL instead of erroring.
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# Optional module-specific pre-processing hook, called as pre(new_data, object).
if (!is.null(object$spec$method$pred$linear_pred$pre))
new_data <- object$spec$method$pred$linear_pred$pre(new_data, object)

# Build the prediction call from the module definition and evaluate it.
pred_call <- make_pred_call(object$spec$method$pred$linear_pred)

res <- eval_tidy(pred_call)

# Optional module-specific post-processing hook, called as post(res, object).
if (!is.null(object$spec$method$pred$linear_pred$post)) {
res <- object$spec$method$pred$linear_pred$post(res, object)
}

# Normalise the result: vectors lose their names; everything else except a
# `tbl_spark` is coerced to a data frame.
if (is.vector(res)) {
res <- unname(res)
} else {
if (!inherits(res, "tbl_spark"))
res <- as.data.frame(res)
}
res
}


#' @export
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_linear_pred.model_fit
predict_linear_pred <- function(object, ...) {
  # S3 generic: the method is selected from the class of `object`.
  UseMethod("predict_linear_pred")
}
49 changes: 49 additions & 0 deletions R/predict_survival.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict.model_fit
#' @param .time Value stored as the `.time` argument of the engine's survival
#'   prediction call (presumably the time points at which survival is
#'   evaluated). Required.
#' @method predict_survival model_fit
#' @export predict_survival.model_fit
#' @export
predict_survival.model_fit <- function(object, new_data, .time, ...) {
  # Produce survival predictions for a fitted model using the `survival`
  # prediction module stored in `object$spec$method$pred`. Returns whatever
  # the evaluated prediction call (and optional post hook) yields, or NULL
  # with a warning when the stored fit is a "try-error".

  # Only the "censored regression" mode is accepted.
  if (object$spec$mode != "censored regression") {
    rlang::abort(glue::glue(
      "`predict_survival()` is for predicting survival probabilities. ",
      "Use `predict_class()` or `predict_classprob()` for ",
      "classification models."
    ))
  }

  if (is.null(object$spec$method$pred$survival)) {
    rlang::abort("No survival prediction method defined for this engine.")
  }

  # A failed fit cannot predict: warn and return NULL instead of erroring.
  if (inherits(object$fit, "try-error")) {
    rlang::warn("Model fit failed; cannot make predictions.")
    return(NULL)
  }

  # `.time` has no default. Without this check a missing value would only
  # surface further down as an opaque "argument is missing" error, after the
  # data had already been prepared.
  if (missing(.time)) {
    rlang::abort("`.time` must be supplied to `predict_survival()`.")
  }

  new_data <- prepare_data(object, new_data)

  # Optional module-specific pre-processing hook, called as
  # pre(new_data, object).
  if (!is.null(object$spec$method$pred$survival$pre)) {
    new_data <- object$spec$method$pred$survival$pre(new_data, object)
  }

  # Add `.time` to the module's argument list so it becomes part of the
  # prediction call built below.
  object$spec$method$pred$survival$args$.time <- .time
  pred_call <- make_pred_call(object$spec$method$pred$survival)

  res <- eval_tidy(pred_call)

  # Optional module-specific post-processing hook, called as
  # post(res, object).
  if (!is.null(object$spec$method$pred$survival$post)) {
    res <- object$spec$method$pred$survival$post(res, object)
  }

  res
}

# NOTE(review): the tags below are plain comments (`#`, not `#'`), so roxygen
# does not process them; kept that way to leave the exports unchanged.
# @export
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_survival <- function(object, ...) {
  # S3 generic: the method is selected from the class of `object`.
  UseMethod("predict_survival")
}
52 changes: 52 additions & 0 deletions R/predict_time.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict.model_fit
#' @method predict_time model_fit
#' @export predict_time.model_fit
#' @export
predict_time.model_fit <- function(object, new_data, ...) {
  # Produce time predictions for a fitted model using the `time` prediction
  # module stored in `object$spec$method$pred`. Returns an unnamed vector, a
  # data frame, or a `tbl_spark` (left as is); returns NULL with a warning
  # when the stored fit is a "try-error".

  # Only the "censored regression" mode is accepted.
  if (object$spec$mode != "censored regression") {
    rlang::abort(glue::glue(
      "`predict_time()` is for predicting time outcomes. ",
      "Use `predict_class()` or `predict_classprob()` for ",
      "classification models."
    ))
  }

  # The model specification must define a module named exactly "time".
  if (!any(names(object$spec$method$pred) == "time")) {
    rlang::abort("No prediction module defined for this model.")
  }

  # A failed fit cannot predict: warn and return NULL instead of erroring.
  if (inherits(object$fit, "try-error")) {
    rlang::warn("Model fit failed; cannot make predictions.")
    return(NULL)
  }

  time_pred <- object$spec$method$pred$time
  new_data <- prepare_data(object, new_data)

  # Optional pre-processing hook, called as pre(new_data, object).
  pre_hook <- time_pred$pre
  if (!is.null(pre_hook)) {
    new_data <- pre_hook(new_data, object)
  }

  # Build the prediction call from the module definition and evaluate it.
  pred_call <- make_pred_call(time_pred)
  res <- eval_tidy(pred_call)

  # Optional post-processing hook, called as post(res, object).
  post_hook <- time_pred$post
  if (!is.null(post_hook)) {
    res <- post_hook(res, object)
  }

  # Normalise the result: vectors lose their names, Spark tables pass
  # through, everything else is coerced to a data frame.
  if (is.vector(res)) {
    return(unname(res))
  }
  if (inherits(res, "tbl_spark")) {
    return(res)
  }
  as.data.frame(res)
}


#' @export
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_time.model_fit
predict_time <- function(object, ...) {
  # S3 generic: the method is selected from the class of `object`.
  UseMethod("predict_time")
}
Loading