Skip to content

Commit bdc2854

Browse files
mattwarkentinhfricktopepo
authored
Adds support for flexsurvspline engine for survival_reg model spec (#831)
* Adds support for flexsurvspline engine for survival_reg model spec * Add PR number to NEWS * leave deprecated `surv_reg()` as is * make `k` tunable the arg name of `flexsurvspline()` is `k`, not `num_knots` the method does not need conditional registration because it's new and was never registered in tune * update `inst/models.tsv` so that `uses_extension()` in the engine docs works * update engine docs and knit * document() * update news, add contributor * import also the generic * remove engine arg from template * notes about case weights * render docs Co-authored-by: Hannah Frick <[email protected]> Co-authored-by: topepo <[email protected]>
1 parent 624dabc commit bdc2854

13 files changed

+250
-2
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ S3method(translate,survival_reg)
8989
S3method(translate,svm_linear)
9090
S3method(translate,svm_poly)
9191
S3method(translate,svm_rbf)
92+
S3method(tunable,survival_reg)
9293
S3method(type_sum,model_fit)
9394
S3method(type_sum,model_spec)
9495
S3method(update,C5_rules)
@@ -316,6 +317,7 @@ importFrom(generics,fit_xy)
316317
importFrom(generics,glance)
317318
importFrom(generics,required_pkgs)
318319
importFrom(generics,tidy)
320+
importFrom(generics,tunable)
319321
importFrom(generics,varying_args)
320322
importFrom(ggplot2,autoplot)
321323
importFrom(glue,glue_collapse)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* Adds documentation and tuning infrastructure for the new `flexsurvspline` engine for the `survival_reg()` model specification from the `censored` package (@mattwarkentin, #831).
4+
35
* The matrix interface for fitting `fit_xy()` now works for the `"censored regression"` mode (#829).
46

57
* The `num_leaves` argument of `boost_tree()`s `lightgbm` engine (via the bonsai package) is now tunable.

R/parsnip-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
## usethis namespace: start
1111
#' @importFrom dplyr arrange bind_cols bind_rows collect full_join group_by
1212
#' @importFrom dplyr mutate pull rename select starts_with summarise tally
13-
#' @importFrom generics varying_args
13+
#' @importFrom generics tunable varying_args
1414
#' @importFrom glue glue_collapse
1515
#' @importFrom pillar type_sum
1616
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr

R/survival_reg_flexsurvspline.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' Flexible parametric survival regression
2+
#'
3+
#' [flexsurv::flexsurvspline()] fits a flexible parametric survival model.
4+
#'
5+
#' @includeRmd man/rmd/survival_reg_flexsurvspline.md details
6+
#'
7+
#' @name details_survival_reg_flexsurvspline
8+
#' @keywords internal
9+
NULL
10+
11+
# See inst/README-DOCS.md for a description of how these files are processed

R/tunable.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ brulee_multinomial_engine_args <-
203203
brulee_mlp_engine_args %>%
204204
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
205205

206+
flexsurvspline_engine_args <-
207+
tibble::tibble(
208+
name = c("k"),
209+
call_info = list(
210+
list(pkg = "dials", fun = "num_knots")
211+
),
212+
source = "model_spec",
213+
component = "survival_reg",
214+
component_id = "engine"
215+
)
216+
206217
# ------------------------------------------------------------------------------
207218

208219
# Lazily registered in .onLoad()
@@ -324,5 +335,14 @@ tunable_mlp <- function(x, ...) {
324335
res
325336
}
326337

338+
#' @export
339+
tunable.survival_reg <- function(x, ...) {
340+
res <- NextMethod()
341+
if (x$engine == "flexsurvspline") {
342+
res <- add_engine_parameters(res, flexsurvspline_engine_args)
343+
}
344+
res
345+
}
346+
327347
# nocov end
328348

inst/models.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
"surv_reg" "regression" "flexsurv" NA
124124
"surv_reg" "regression" "survival" NA
125125
"survival_reg" "censored regression" "flexsurv" "censored"
126+
"survival_reg" "censored regression" "flexsurvspline" "censored"
126127
"survival_reg" "censored regression" "survival" "censored"
127128
"svm_linear" "classification" "kernlab" NA
128129
"svm_linear" "classification" "LiblineaR" NA

man/details_auto_ml_h2o.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_survival_reg_flexsurv.Rd

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_survival_reg_flexsurvspline.Rd

Lines changed: 80 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/survival_reg_flexsurv.Rmd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ For this engine, stratification cannot be specified via [`strata()`], please see
4545
```{r child = "template-survival-mean.Rmd"}
4646
```
4747

48+
## Case weights
49+
50+
```{r child = "template-uses-case-weights.Rmd"}
51+
```
52+
4853
## Saving fitted model objects
4954

5055
```{r child = "template-butcher.Rmd"}

man/rmd/survival_reg_flexsurv.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ For this engine, stratification cannot be specified via [`strata()`], please see
4848

4949
Predictions of type `"time"` are predictions of the mean survival time.
5050

51+
## Case weights
52+
53+
54+
This model can utilize case weights during model fitting. To use them, see the documentation in [case_weights] and the examples on `tidymodels.org`.
55+
56+
The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that expect vectors of case weights.
57+
5158
## Saving fitted model objects
5259

5360

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
```{r, child = "aaa.Rmd", include = FALSE}
2+
```
3+
4+
`r descr_models("survival_reg", "flexsurvspline")`
5+
6+
## Tuning Parameters
7+
8+
This model has one engine-specific tuning parameter:
9+
10+
* `k`: Number of knots in the spline. The default is `k = 0`.
11+
12+
## Translation from parsnip to the original package
13+
14+
`r uses_extension("survival_reg", "flexsurvspline", "censored regression")`
15+
16+
```{r flexsurvspline-creg}
17+
library(censored)
18+
19+
survival_reg() %>%
20+
set_engine("flexsurvspline") %>%
21+
set_mode("censored regression") %>%
22+
translate()
23+
```
24+
25+
## Other details
26+
27+
The main interface for this model uses the formula method since the model specification typically involved the use of [survival::Surv()].
28+
29+
For this engine, stratification cannot be specified via [`strata()`], please see [flexsurv::flexsurvspline()] for alternative specifications.
30+
31+
```{r child = "template-survival-mean.Rmd"}
32+
```
33+
34+
## Case weights
35+
36+
```{r child = "template-uses-case-weights.Rmd"}
37+
```
38+
39+
40+
## Saving fitted model objects
41+
42+
```{r child = "template-butcher.Rmd"}
43+
```
44+
45+
46+
## References
47+
48+
- Jackson, C. 2016. `flexsurv`: A Platform for Parametric Survival Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
2+
3+
4+
For this engine, there is a single mode: censored regression
5+
6+
## Tuning Parameters
7+
8+
This model has one engine-specific tuning parameter:
9+
10+
* `k`: Number of knots in the spline. The default is `k = 0`.
11+
12+
## Translation from parsnip to the original package
13+
14+
The **censored** extension package is required to fit this model.
15+
16+
17+
```r
18+
library(censored)
19+
20+
survival_reg() %>%
21+
set_engine("flexsurvspline") %>%
22+
set_mode("censored regression") %>%
23+
translate()
24+
```
25+
26+
```
27+
## Parametric Survival Regression Model Specification (censored regression)
28+
##
29+
## Computational engine: flexsurvspline
30+
##
31+
## Model fit template:
32+
## flexsurv::flexsurvspline(formula = missing_arg(), data = missing_arg(),
33+
## weights = missing_arg())
34+
```
35+
36+
## Other details
37+
38+
The main interface for this model uses the formula method since the model specification typically involved the use of [survival::Surv()].
39+
40+
For this engine, stratification cannot be specified via [`strata()`], please see [flexsurv::flexsurvspline()] for alternative specifications.
41+
42+
43+
44+
Predictions of type `"time"` are predictions of the mean survival time.
45+
46+
## Case weights
47+
48+
49+
This model can utilize case weights during model fitting. To use them, see the documentation in [case_weights] and the examples on `tidymodels.org`.
50+
51+
The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that expect vectors of case weights.
52+
53+
54+
## Saving fitted model objects
55+
56+
57+
This model object contains data that are not required to make predictions. When saving the model for the purpose of prediction, the size of the saved object might be substantially reduced by using functions from the [butcher](https://butcher.tidymodels.org) package.
58+
59+
60+
## References
61+
62+
- Jackson, C. 2016. `flexsurv`: A Platform for Parametric Survival Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.

0 commit comments

Comments
 (0)