-
Notifications
You must be signed in to change notification settings - Fork 92
WIP: adding new prediction types for Survnip #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Looking back at my notes, I think I favor a mode of "censored regression" rather than "risk prediction". Technically, if we produce survival probabilities, those aren't really the risk. The prediction format for survival probabilities would have Here's an example using library(tidymodels)
#> ── Attaching packages ───────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.8 ✓ rsample 0.0.7
#> ✓ dplyr 1.0.1 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.1
#> ✓ infer 0.5.2 ✓ tune 0.1.1.9000
#> ✓ modeldata 0.0.2 ✓ workflows 0.1.2.9000
#> ✓ parsnip 0.1.3 ✓ yardstick 0.0.7.9000
#> ✓ purrr 0.3.4
#> ── Conflicts ──────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(flexsurv)
#> Loading required package: survival data(ovarian)
fit <- flexsurvreg(formula = Surv(futime, fustat) ~ 1, data = ovarian, dist="weibull")
summary(fit, ovarian[1:2,], t = c(50, 100, 150)) %>%
map_dfr(~ .x) %>%
mutate(row = rep(1:2, each = 3)) %>%
dplyr::select(.time = time, .pred_survivial = est, row) %>%
group_nest(row, .key = ".pred") %>%
select(-row)
#> # A tibble: 2 x 1
#> .pred
#> <list<tbl_df[,2]>>
#> 1 [3 × 2]
#> 2 [3 × 2] Created on 2020-08-04 by the reprex package (v0.3.0) |
Alright! I have changed everything to "censored regression". We are technically doing right-censored regression. So if we add left-censored regression we would need to remember to clarify. 👍 on nested tibbles. Where would be the cleanest place to convert? doing |
Changes to predict_survival should be done now library(survnip)
#> Loading required package: parsnip
library(survival)
cox_mod <-
cox_reg() %>%
set_engine("survival") %>%
fit(Surv(time, status) ~ age + ph.ecog, data = lung)
pred_vals <- predict(cox_mod, new_data = lung, type = "survival", .time = 100:200)
pred_vals
#> # A tibble: 228 x 1
#> .pred_survival
#> <list<tbl_df[,2]>>
#> 1 [101 × 2]
#> 2 [101 × 2]
#> 3 [101 × 2]
#> 4 [101 × 2]
#> 5 [101 × 2]
#> 6 [101 × 2]
#> 7 [101 × 2]
#> 8 [101 × 2]
#> 9 [101 × 2]
#> 10 [101 × 2]
#> # … with 218 more rows
pred_vals$.pred_survival[[1]]
#> # A tibble: 101 x 2
#> .time .pred_survival
#> <chr> <dbl>
#> 1 100 0.855
#> 2 101 0.855
#> 3 102 0.855
#> 4 103 0.855
#> 5 104 0.855
#> 6 105 0.850
#> 7 106 0.850
#> 8 107 0.840
#> 9 108 0.840
#> 10 109 0.840
#> # … with 91 more rows
tidyr::unnest(pred_vals, cols = c(.pred_survival))
#> # A tibble: 23,028 x 2
#> .time .pred_survival
#> <chr> <dbl>
#> 1 100 0.855
#> 2 101 0.855
#> 3 102 0.855
#> 4 103 0.855
#> 5 104 0.855
#> 6 105 0.850
#> 7 106 0.850
#> 8 107 0.840
#> 9 108 0.840
#> 10 109 0.840
#> # … with 23,018 more rows Created on 2020-08-04 by the reprex package (v0.3.0) |
Package related to this PR is located here: https://github.com/EmilHvitfeldt/survnip |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall; minor changes.
We might want to keep this in a branch; if we keep it in main
, people will think that everything is there.
Thanks!
@@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown") | |||
# ------------------------------------------------------------------------------ | |||
|
|||
pred_types <- | |||
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile") | |||
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile", | |||
"time", "survival", "linear_pred") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the infrastructure for a "hazard" type?
@@ -216,6 +223,54 @@ format_classprobs <- function(x) { | |||
x | |||
} | |||
|
|||
format_time <- function(x) { | |||
if (inherits(x, "tbl_spark")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove the spark stuff from the survival bits
#' @export predict_linear_pred.model_fit | ||
#' @export | ||
predict_linear_pred.model_fit <- function(object, new_data, ...) { | ||
if (object$spec$mode != "censored regression") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd remove this check. We can do these types of predictions for all types of models.
) | ||
# | ||
# set_new_model("surv_reg") | ||
# set_model_mode("surv_reg", "regression") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd have to soft-deprecate these. I'll PR into your package to change the new surv_reg()
to survival_reg()
.
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue. |
This PR extents
predict()
to allow fortype = time
andtype = survival
for risk prediction.