Skip to content

Commit fe45dab

Browse files
authored
Merge pull request #449 from tidymodels/survival_reg
Add model spec for `survival_reg`
2 parents 9fef469 + ac4486a commit fe45dab

File tree

6 files changed

+255
-0
lines changed

6 files changed

+255
-0
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ S3method(print,nearest_neighbor)
5757
S3method(print,nullmodel)
5858
S3method(print,rand_forest)
5959
S3method(print,surv_reg)
60+
S3method(print,survival_reg)
6061
S3method(print,svm_linear)
6162
S3method(print,svm_poly)
6263
S3method(print,svm_rbf)
@@ -96,6 +97,7 @@ S3method(update,multinom_reg)
9697
S3method(update,nearest_neighbor)
9798
S3method(update,rand_forest)
9899
S3method(update,surv_reg)
100+
S3method(update,survival_reg)
99101
S3method(update,svm_linear)
100102
S3method(update,svm_poly)
101103
S3method(update,svm_rbf)
@@ -198,6 +200,7 @@ export(show_fit)
198200
export(show_model_info)
199201
export(stan_conf_int)
200202
export(surv_reg)
203+
export(survival_reg)
201204
export(svm_linear)
202205
export(svm_poly)
203206
export(svm_rbf)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424) and the `kernlab` engine (#438), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429). These models can use sparse matrices via `fit_xy()` (#447).
88

9+
* New model specification `survival_reg()` for the new mode `"censored regression"`. (#444)
10+
911
# parsnip 0.1.5
1012

1113
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

R/survival_reg.R

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#' General Interface for Parametric Survival Models
2+
#'
3+
#' `survival_reg()` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using
5+
#' R. The main argument for the
6+
#' model is:
7+
#' \itemize{
8+
#' \item \code{dist}: The probability distribution of the outcome.
9+
#' }
10+
#' This argument is converted to its specific names at the
11+
#' time that the model is fit. Other options and argument can be
12+
#' set using `set_engine()`. If left to its default
13+
#' here (`NULL`), the value is taken from the underlying model
14+
#' functions.
15+
#'
16+
#' @inheritParams boost_tree
17+
#' @param mode A single character string for the type of model.
18+
#' The only possible value for this model is "censored regression".
19+
#' @param dist A character string for the outcome distribution. "weibull" is
20+
#' the default.
21+
#' @details
22+
#' The data given to the function are not saved and are only used
23+
#' to determine the _mode_ of the model. For `survival_reg()`,the
24+
#' mode will always be "censored regression".
25+
#'
26+
#' Since survival models typically involve censoring (and require the use of
27+
#' [survival::Surv()] objects), the [fit()] function will require that the
28+
#' survival model be specified via the formula interface.
29+
#'
30+
#' @seealso [fit()], [survival::Surv()]
31+
#' @examples
32+
#' survival_reg()
33+
#' # Parameters can be represented by a placeholder:
34+
#' survival_reg(dist = varying())
35+
#'
36+
#' @export
37+
survival_reg <- function(mode = "censored regression", dist = NULL) {
38+
39+
args <- list(
40+
dist = enquo(dist)
41+
)
42+
43+
new_model_spec(
44+
"survival_reg",
45+
args = args,
46+
eng_args = NULL,
47+
mode = mode,
48+
method = NULL,
49+
engine = NULL
50+
)
51+
}
52+
53+
#' @export
54+
print.survival_reg <- function(x, ...) {
55+
cat("Parametric Survival Regression Model Specification (", x$mode, ")\n\n", sep = "")
56+
model_printer(x, ...)
57+
58+
if (!is.null(x$method$fit$args)) {
59+
cat("Model fit template:\n")
60+
print(show_call(x))
61+
}
62+
63+
invisible(x)
64+
}
65+
66+
# ------------------------------------------------------------------------------
67+
68+
#' Update a Parametric Survival Regression Specification
69+
#'
70+
#' If parameters need to be modified, this function can be used
71+
#' in lieu of recreating the object from scratch.
72+
#'
73+
#' @inheritParams update.boost_tree
74+
#' @param object A survival regression model specification.
75+
#' @examples
76+
#' model <- survival_reg(dist = "weibull")
77+
#' model
78+
#' update(model, dist = "lnorm")
79+
#' @method update survival_reg
80+
#' @rdname survival_reg
81+
#' @export
82+
update.survival_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) {
83+
84+
eng_args <- update_engine_parameters(object$eng_args, ...)
85+
86+
if (!is.null(parameters)) {
87+
parameters <- check_final_param(parameters)
88+
}
89+
90+
args <- list(
91+
dist = enquo(dist)
92+
)
93+
94+
args <- update_main_parameters(args, parameters)
95+
96+
if (fresh) {
97+
object$args <- args
98+
object$eng_args <- eng_args
99+
} else {
100+
null_args <- map_lgl(args, null_value)
101+
if (any(null_args))
102+
args <- args[!null_args]
103+
if (length(args) > 0)
104+
object$args[names(args)] <- args
105+
if (length(eng_args) > 0)
106+
object$eng_args[names(eng_args)] <- eng_args
107+
}
108+
109+
new_model_spec(
110+
"survival_reg",
111+
args = object$args,
112+
eng_args = object$eng_args,
113+
mode = object$mode,
114+
method = NULL,
115+
engine = object$engine
116+
)
117+
}

R/survival_reg_data.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
set_new_model("survival_reg")
3+
set_model_mode("survival_reg", "censored regression")
4+
5+
# ------------------------------------------------------------------------------
6+
7+
# parnip just contains the model specification, the engines are the censored package.

man/survival_reg.Rd

Lines changed: 68 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-survival_reg.R

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
test_that("primary argument", {
3+
new_empty_quosure <- function(expr) {
4+
rlang::new_quosure(expr, env = rlang::empty_env())
5+
}
6+
7+
normal <- survival_reg(dist = "lnorm")
8+
expect_equal(
9+
normal$args,
10+
list(dist = new_empty_quosure("lnorm"))
11+
)
12+
13+
dist_v <- survival_reg(dist = varying())
14+
expect_equal(
15+
dist_v$args,
16+
list(dist = new_empty_quosure(varying()))
17+
)
18+
})
19+
20+
test_that("updating", {
21+
new_empty_quosure <- function(expr) {
22+
rlang::new_quosure(expr, env = rlang::empty_env())
23+
}
24+
25+
basic <- survival_reg()
26+
27+
update_chr <- update(basic, dist = "lnorm")
28+
expect_equal(
29+
update_chr$args,
30+
list(dist = new_empty_quosure("lnorm"))
31+
)
32+
33+
param_tibb <- tibble::tibble(dist = "weibull")
34+
update_tibb <- update(basic, param_tibb)
35+
expect_equal(
36+
update_tibb$args,
37+
list(dist = "weibull")
38+
)
39+
40+
param_list <- as.list(param_tibb)
41+
update_list <- update(basic, param_list)
42+
expect_equal(
43+
update_list$args,
44+
list(dist = "weibull")
45+
)
46+
47+
})
48+
49+
test_that("bad input", {
50+
expect_error(survival_reg(mode = ", classification"))
51+
})
52+
53+
test_that("wrong fit interface", {
54+
expect_error(
55+
survival_reg() %>% fit_xy(),
56+
"must use the formula interface"
57+
)
58+
})

0 commit comments

Comments
 (0)