Skip to content

time to eval_time #936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9004
Version: 1.0.4.9005
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
29 changes: 19 additions & 10 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#' linear predictors). Default value is `FALSE`.
#' \item `quantile`: for `type` equal to `quantile`, the quantiles of the
#' distribution. Default is `(1:9)/10`.
#' \item `time`: for `type` equal to `"survival"` or `"hazard"`, the
#' \item `eval_time`: for `type` equal to `"survival"` or `"hazard"`, the
#' time points at which the survival probability or hazard is estimated.
#' }
#' @details For `type = NULL`, `predict()` uses
Expand All @@ -48,7 +48,7 @@
#'
#' ## Censored regression predictions
#'
#' For censored regression, a numeric vector for `time` is required when
#' For censored regression, a numeric vector for `eval_time` is required when
#' survival or hazard probabilities are requested. Also, when
#' `type = "linear_pred"`, censored regression models will by default be
#' formatted such that the linear predictor _increases_ with time. This may
Expand Down Expand Up @@ -83,11 +83,11 @@
#'
#' For `type = "survival"`, the tibble has a `.pred` column, which is
#' a list-column. Each list element contains a tibble with columns
#' `.time` and `.pred_survival` (and perhaps other columns).
#' `.eval_time` and `.pred_survival` (and perhaps other columns).
#'
#' For `type = "hazard"`, the tibble has a `.pred` column, which is
#' a list-column. Each list element contains a tibble with columns
#' `.time` and `.pred_hazard` (and perhaps other columns).
#' `.eval_time` and `.pred_hazard` (and perhaps other columns).
#'
#' Using `type = "raw"` with `predict.model_fit()` will return
#' the unadulterated results of the prediction function.
Expand Down Expand Up @@ -328,7 +328,8 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

other_args <- c("interval", "level", "std_error", "quantile", "time", "increasing")
other_args <- c("interval", "level", "std_error", "quantile",
"time", "eval_time", "increasing")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
Expand All @@ -342,7 +343,15 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
}

# ----------------------------------------------------------------------------
# places where time should not be given
# places where eval_time should not be given
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"`eval_time` should only be passed to `predict()` when `type` is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
)
}
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
Expand All @@ -351,12 +360,12 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
)
)
}
# when time should be passed
if (!any(nms == "time") & type %in% c("survival", "hazard")) {
# when eval_time should be passed
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"When using 'type' values of 'survival' or 'hazard' are given,",
"a numeric vector 'time' should also be given."
"When using `type` values of 'survival' or 'hazard',",
"a numeric vector `eval_time` should also be given."
)
)
}
Expand Down
53 changes: 32 additions & 21 deletions R/predict_hazard.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,47 @@
#' @method predict_hazard model_fit
#' @export predict_hazard.model_fit
#' @export
predict_hazard.model_fit <-
function(object, new_data, time, ...) {

check_spec_pred_type(object, "hazard")
predict_hazard.model_fit <- function(object,
new_data,
eval_time,
time = deprecated(),
...) {
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"1.0.4.9005",
"predict_hazard(time)",
"predict_hazard(eval_time)"
)
eval_time <- time
}

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
}
check_spec_pred_type(object, "hazard")

new_data <- prepare_data(object, new_data)
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
}

# preprocess data
if (!is.null(object$spec$method$pred$hazard$pre))
new_data <- object$spec$method$pred$hazard$pre(new_data, object)
new_data <- prepare_data(object, new_data)

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$hazard$args$time <- time
pred_call <- make_pred_call(object$spec$method$pred$hazard)
# preprocess data
if (!is.null(object$spec$method$pred$hazard$pre))
new_data <- object$spec$method$pred$hazard$pre(new_data, object)

res <- eval_tidy(pred_call)
# Pass some extra arguments to be used in post-processor
object$spec$method$pred$hazard$args$eval_time <- eval_time
pred_call <- make_pred_call(object$spec$method$pred$hazard)

# post-process the predictions
if(!is.null(object$spec$method$pred$hazard$post)) {
res <- object$spec$method$pred$hazard$post(res, object)
}
res <- eval_tidy(pred_call)

res
# post-process the predictions
if(!is.null(object$spec$method$pred$hazard$post)) {
res <- object$spec$method$pred$hazard$post(res, object)
}

res
}

# @export
# @keywords internal
# @rdname other_predict
Expand Down
55 changes: 34 additions & 21 deletions R/predict_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,49 @@
#' @method predict_survival model_fit
#' @export predict_survival.model_fit
#' @export
predict_survival.model_fit <-
function(object, new_data, time, interval = "none", level = 0.95, ...) {

check_spec_pred_type(object, "survival")
predict_survival.model_fit <- function(object,
new_data,
eval_time,
time = deprecated(),
interval = "none",
level = 0.95,
...) {
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"1.0.4.9005",
"predict_survival(time)",
"predict_survival(eval_time)"
)
eval_time <- time
}

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
}
check_spec_pred_type(object, "survival")

new_data <- prepare_data(object, new_data)
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
}

# preprocess data
if (!is.null(object$spec$method$pred$survival$pre))
new_data <- object$spec$method$pred$survival$pre(new_data, object)
new_data <- prepare_data(object, new_data)

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$survival$args$time <- time
pred_call <- make_pred_call(object$spec$method$pred$survival)
# preprocess data
if (!is.null(object$spec$method$pred$survival$pre))
new_data <- object$spec$method$pred$survival$pre(new_data, object)

res <- eval_tidy(pred_call)
# Pass some extra arguments to be used in post-processor
object$spec$method$pred$survival$args$eval_time <- eval_time
pred_call <- make_pred_call(object$spec$method$pred$survival)

# post-process the predictions
if(!is.null(object$spec$method$pred$survival$post)) {
res <- object$spec$method$pred$survival$post(res, object)
}
res <- eval_tidy(pred_call)

res
# post-process the predictions
if(!is.null(object$spec$method$pred$survival$post)) {
res <- object$spec$method$pred$survival$post(res, object)
}

res
}

#' @export
#' @keywords internal
#' @rdname other_predict
Expand Down
14 changes: 11 additions & 3 deletions man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/predict.model_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.