WIP: adding new prediction types for Survnip #359

Closed
wants to merge 8 commits into from

EmilHvitfeldt
Member

This PR extends predict() to allow type = "time" and type = "survival" for risk prediction.

@topepo
Member

topepo commented Aug 4, 2020

Looking back at my notes, I think I favor a mode of "censored regression" rather than "risk prediction". Technically, if we produce survival probabilities, those aren't really the risk.

The prediction format for survival probabilities would have nrow(input) == nrow(output). This means that the results should be nested tibbles (even without multi_predict()).

Here's an example using flexsurv:

library(tidymodels)
#> ── Attaching packages ───────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0          ✓ recipes   0.1.13    
#> ✓ dials     0.0.8          ✓ rsample   0.0.7     
#> ✓ dplyr     1.0.1          ✓ tibble    3.0.3     
#> ✓ ggplot2   3.3.2          ✓ tidyr     1.1.1     
#> ✓ infer     0.5.2          ✓ tune      0.1.1.9000
#> ✓ modeldata 0.0.2          ✓ workflows 0.1.2.9000
#> ✓ parsnip   0.1.3          ✓ yardstick 0.0.7.9000
#> ✓ purrr     0.3.4
#> ── Conflicts ──────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
library(flexsurv)
#> Loading required package: survival
data(ovarian)
fit <- flexsurvreg(formula = Surv(futime, fustat) ~ 1, data = ovarian, dist="weibull")

summary(fit, ovarian[1:2,], t = c(50, 100, 150)) %>% 
  map_dfr(~ .x) %>% 
  mutate(row = rep(1:2, each = 3)) %>% 
  dplyr::select(.time = time, .pred_survival = est, row) %>% 
  group_nest(row, .key = ".pred") %>% 
  select(-row)
#> # A tibble: 2 x 1
#>                .pred
#>   <list<tbl_df[,2]>>
#> 1            [3 × 2]
#> 2            [3 × 2]

Created on 2020-08-04 by the reprex package (v0.3.0)

@EmilHvitfeldt
Member Author

EmilHvitfeldt commented Aug 5, 2020

Alright! I have changed everything to "censored regression". We are technically doing right-censored regression. So if we add left-censored regression we would need to remember to clarify.

👍 on nested tibbles. Where would be the cleanest place to convert: in parsnip::set_pred() with the post argument, or in format_survival()? It feels like we would need a post function almost no matter what, so it might make sense to create a list of tibbles there and have format_survival() handle the full tibble.
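
To make the conversion question concrete, here is a minimal sketch, not the code in this PR. It assumes an engine hands back a matrix of survival probabilities with one row per observation and one column per time point; the helper name `survival_matrix_to_nested()` and that matrix layout are hypothetical and exist only for illustration.

```r
library(tibble)

# Hypothetical helper (not part of parsnip or survnip): turn an
# n x length(times) matrix of survival probabilities into a one-column
# tibble whose list-column holds one small tibble per input row.
survival_matrix_to_nested <- function(prob_mat, times) {
  stopifnot(ncol(prob_mat) == length(times))
  per_row <- lapply(seq_len(nrow(prob_mat)), function(i) {
    tibble(.time = times, .pred_survival = prob_mat[i, ])
  })
  tibble(.pred = per_row)
}

# Made-up probabilities for two observations at three time points.
probs <- matrix(c(0.90, 0.80, 0.70,
                  0.95, 0.85, 0.60), nrow = 2, byrow = TRUE)
nested <- survival_matrix_to_nested(probs, times = c(50, 100, 150))

# The output keeps one row per input row, as requested above.
nrow(nested) == nrow(probs)
```

Whether this step lives in a post function or in format_survival() is the open design question; the sketch only shows the shape of the result.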

@EmilHvitfeldt
Member Author

Changes to predict_survival should be done now

library(survnip)
#> Loading required package: parsnip
library(survival)

cox_mod <-
  cox_reg() %>%
  set_engine("survival") %>%
  fit(Surv(time, status) ~ age + ph.ecog, data = lung)

pred_vals <- predict(cox_mod, new_data = lung, type = "survival", .time = 100:200)

pred_vals
#> # A tibble: 228 x 1
#>        .pred_survival
#>    <list<tbl_df[,2]>>
#>  1          [101 × 2]
#>  2          [101 × 2]
#>  3          [101 × 2]
#>  4          [101 × 2]
#>  5          [101 × 2]
#>  6          [101 × 2]
#>  7          [101 × 2]
#>  8          [101 × 2]
#>  9          [101 × 2]
#> 10          [101 × 2]
#> # … with 218 more rows

pred_vals$.pred_survival[[1]]
#> # A tibble: 101 x 2
#>    .time .pred_survival
#>    <chr>          <dbl>
#>  1 100            0.855
#>  2 101            0.855
#>  3 102            0.855
#>  4 103            0.855
#>  5 104            0.855
#>  6 105            0.850
#>  7 106            0.850
#>  8 107            0.840
#>  9 108            0.840
#> 10 109            0.840
#> # … with 91 more rows

tidyr::unnest(pred_vals, cols = c(.pred_survival))
#> # A tibble: 23,028 x 2
#>    .time .pred_survival
#>    <chr>          <dbl>
#>  1 100            0.855
#>  2 101            0.855
#>  3 102            0.855
#>  4 103            0.855
#>  5 104            0.855
#>  6 105            0.850
#>  7 106            0.850
#>  8 107            0.840
#>  9 108            0.840
#> 10 109            0.840
#> # … with 23,018 more rows

Created on 2020-08-04 by the reprex package (v0.3.0)

@EmilHvitfeldt
Member Author

Package related to this PR is located here: https://github.com/EmilHvitfeldt/survnip

Member

@topepo topepo left a comment

Looks good overall; minor changes.

We might want to keep this in a branch; if we keep it in main, people will think that everything is there.

Thanks!

@@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
 # ------------------------------------------------------------------------------

 pred_types <-
-  c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
+  c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
+    "time", "survival", "linear_pred")
Member

Can you add the infrastructure for a "hazard" type?
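
As a rough sketch of what the first step of that request might look like, the vector from the diff above can be extended with "hazard". The validator `check_pred_type_sketch()` is a hypothetical stand-in written for this example; it is not parsnip's own checking code, and the real infrastructure would also need the matching predict and format functions.

```r
# Prediction types from the diff above, plus the requested "hazard".
pred_types <-
  c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
    "time", "survival", "linear_pred", "hazard")

# Hypothetical validator (name and error message are illustrative only).
check_pred_type_sketch <- function(type) {
  if (!type %in% pred_types) {
    stop(
      "`type` should be one of: ",
      paste0("'", pred_types, "'", collapse = ", "),
      call. = FALSE
    )
  }
  invisible(type)
}

check_pred_type_sketch("hazard")
```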

@@ -216,6 +223,54 @@ format_classprobs <- function(x) {
   x
 }

+format_time <- function(x) {
+  if (inherits(x, "tbl_spark"))
Member

You can remove the spark stuff from the survival bits

#' @export predict_linear_pred.model_fit
#' @export
predict_linear_pred.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "censored regression")
Member

I'd remove this check. We can do these types of predictions for all types of models.
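
The point that linear predictors are not specific to censored regression can be illustrated outside parsnip. The sketch below uses only stats::glm() and survival::coxph(); the choice of models and of the mtcars and lung data sets is arbitrary and made for this example.

```r
library(survival)

# Linear predictor from a logistic regression (stats::glm).
glm_fit <- glm(am ~ wt, data = mtcars, family = binomial())
lp_glm <- predict(glm_fit, newdata = mtcars, type = "link")

# Linear predictor from a Cox model (survival::coxph).
cox_fit <- coxph(Surv(time, status) ~ age + ph.ecog, data = lung)
lp_cox <- predict(cox_fit, newdata = lung, type = "lp")

# Both return one value per row of the new data.
length(lp_glm) == nrow(mtcars)
length(lp_cox) == nrow(lung)
```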

)
#
# set_new_model("surv_reg")
# set_model_mode("surv_reg", "regression")
Member

We'd have to soft-deprecate these. I'll PR into your package to change the new surv_reg() to survival_reg().

EmilHvitfeldt added a commit to EmilHvitfeldt/parsnip that referenced this pull request Dec 3, 2020
@github-actions

github-actions bot commented Mar 6, 2021

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 6, 2021

2 participants