Skip to content

Commit 59214c9

Browse files
authored
Merge pull request #455 from tidymodels/proportional_hazards
Model spec for proportional hazards models
2 parents 66aa91d + e52aff2 commit 59214c9

File tree

6 files changed

+281
-0
lines changed

6 files changed

+281
-0
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ S3method(print,model_spec)
5555
S3method(print,multinom_reg)
5656
S3method(print,nearest_neighbor)
5757
S3method(print,nullmodel)
58+
S3method(print,proportional_hazards)
5859
S3method(print,rand_forest)
5960
S3method(print,surv_reg)
6061
S3method(print,survival_reg)
@@ -95,6 +96,7 @@ S3method(update,mars)
9596
S3method(update,mlp)
9697
S3method(update,multinom_reg)
9798
S3method(update,nearest_neighbor)
99+
S3method(update,proportional_hazards)
98100
S3method(update,rand_forest)
99101
S3method(update,surv_reg)
100102
S3method(update,survival_reg)
@@ -176,6 +178,7 @@ export(predict_survival.model_fit)
176178
export(predict_time)
177179
export(predict_time.model_fit)
178180
export(prepare_data)
181+
export(proportional_hazards)
179182
export(rand_forest)
180183
export(repair_call)
181184
export(req_pkgs)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
* New model specification `survival_reg()` for the new mode `"censored regression"`. (#444)
1010

11+
* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451).
12+
1113
# parsnip 0.1.5
1214

1315
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

R/proportional_hazards.R

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#' General Interface for Proportional Hazards Models
2+
#'
3+
#' `proportional_hazards()` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using different packages
5+
#' in R. The main arguments for the model are:
6+
#' \itemize{
7+
#' \item \code{penalty}: The total amount of regularization
8+
#' in the model. Note that this must be zero for some engines.
9+
#' \item \code{mixture}: The mixture amounts of different types of
10+
#' regularization (see below). Note that this will be ignored for some engines.
11+
#' }
12+
#' These arguments are converted to their specific names at the
13+
#' time that the model is fit. Other options and arguments can be
14+
#' set using `set_engine()`. If left to their defaults
15+
#' here (`NULL`), the values are taken from the underlying model
16+
#' functions. If parameters need to be modified, `update()` can be used
17+
#' in lieu of recreating the object from scratch.
18+
#'
19+
#' @param mode A single character string for the type of model.
20+
#' Possible values for this model are "unknown", or "censored regression".
21+
#' @inheritParams linear_reg
22+
#'
23+
#' @details
24+
#' Proportional hazards models include the Cox model.
25+
#' For `proportional_hazards()`, the mode will always be "censored regression".
26+
#'
27+
#' @examples
28+
#' show_engines("proportional_hazards")
29+
#'
30+
#' @export
31+
proportional_hazards <- function(mode = "censored regression",
32+
penalty = NULL,
33+
mixture = NULL) {
34+
35+
args <- list(
36+
penalty = enquo(penalty),
37+
mixture = enquo(mixture)
38+
)
39+
40+
new_model_spec(
41+
"proportional_hazards",
42+
args = args,
43+
eng_args = NULL,
44+
mode = mode,
45+
method = NULL,
46+
engine = NULL
47+
)
48+
}
49+
50+
#' @export
51+
print.proportional_hazards <- function(x, ...) {
52+
cat("Proportional Hazards Model Specification (", x$mode, ")\n\n", sep = "")
53+
model_printer(x, ...)
54+
55+
if (!is.null(x$method$fit$args)) {
56+
cat("Model fit template:\n")
57+
print(show_call(x))
58+
}
59+
60+
invisible(x)
61+
}
62+
63+
# ------------------------------------------------------------------------------
64+
65+
#' @param object A proportional hazards model specification.
66+
#' @param ... Not used for `update()`.
67+
#' @param fresh A logical for whether the arguments should be
68+
#' modified in-place of or replaced wholesale.
69+
#' @examples
70+
#' model <- proportional_hazards(penalty = 10, mixture = 0.1)
71+
#' model
72+
#' update(model, penalty = 1)
73+
#' update(model, penalty = 1, fresh = TRUE)
74+
#' @method update proportional_hazards
75+
#' @rdname proportional_hazards
76+
#' @export
77+
update.proportional_hazards <- function(object,
78+
parameters = NULL,
79+
penalty = NULL,
80+
mixture = NULL,
81+
fresh = FALSE, ...) {
82+
83+
eng_args <- update_engine_parameters(object$eng_args, ...)
84+
85+
if (!is.null(parameters)) {
86+
parameters <- check_final_param(parameters)
87+
}
88+
args <- list(
89+
penalty = enquo(penalty),
90+
mixture = enquo(mixture)
91+
)
92+
93+
args <- update_main_parameters(args, parameters)
94+
95+
if (fresh) {
96+
object$args <- args
97+
object$eng_args <- eng_args
98+
} else {
99+
null_args <- map_lgl(args, null_value)
100+
if (any(null_args))
101+
args <- args[!null_args]
102+
if (length(args) > 0)
103+
object$args[names(args)] <- args
104+
if (length(eng_args) > 0)
105+
object$eng_args[names(eng_args)] <- eng_args
106+
}
107+
108+
new_model_spec(
109+
"proportional_hazards",
110+
args = object$args,
111+
eng_args = object$eng_args,
112+
mode = object$mode,
113+
method = NULL,
114+
engine = object$engine
115+
)
116+
}

R/proportional_hazards_data.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
# parnip just contains the model specification, the engines are the censored package.
3+
4+
set_new_model("proportional_hazards")
5+
set_model_mode("proportional_hazards", "censored regression")

man/proportional_hazards.Rd

Lines changed: 78 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
2+
test_that("primary arguments", {
3+
new_empty_quosure <- function(expr) {
4+
rlang::new_quosure(expr, env = rlang::empty_env())
5+
}
6+
7+
ph_penalty <- proportional_hazards(penalty = 0.05)
8+
expect_equal(
9+
ph_penalty$args,
10+
list(penalty = new_empty_quosure(0.05),
11+
mixture = new_empty_quosure(NULL))
12+
)
13+
14+
ph_mixture <- proportional_hazards(mixture = 0.34)
15+
expect_equal(
16+
ph_mixture$args,
17+
list(penalty = new_empty_quosure(NULL),
18+
mixture = new_empty_quosure(0.34))
19+
)
20+
21+
ph_mixture_v <- proportional_hazards(mixture = varying())
22+
expect_equal(
23+
ph_mixture_v$args,
24+
list(penalty = new_empty_quosure(NULL),
25+
mixture = new_empty_quosure(varying()))
26+
)
27+
})
28+
29+
test_that("printing", {
30+
expect_output(
31+
print(proportional_hazards()),
32+
"Proportional Hazards Model Specification \\(censored regression\\)"
33+
)
34+
})
35+
36+
test_that("updating", {
37+
new_empty_quosure <- function(expr) {
38+
rlang::new_quosure(expr, env = rlang::empty_env())
39+
}
40+
41+
basic <- proportional_hazards()
42+
43+
update_num <- update(basic, penalty = 0.05)
44+
expect_equal(
45+
update_num$args,
46+
list(penalty = new_empty_quosure(0.05),
47+
mixture = new_empty_quosure(NULL))
48+
)
49+
50+
param_tibb <- tibble::tibble(penalty = 0.05)
51+
update_tibb <- update(basic, param_tibb)
52+
expect_equal(
53+
update_tibb$args,
54+
list(penalty = 0.05,
55+
mixture = new_empty_quosure(NULL))
56+
)
57+
58+
param_list <- as.list(param_tibb)
59+
update_list <- update(basic, param_list)
60+
expect_equal(
61+
update_list$args,
62+
list(penalty = 0.05,
63+
mixture = new_empty_quosure(NULL))
64+
)
65+
})
66+
67+
68+
test_that("bad input", {
69+
expect_error(proportional_hazards(mode = ", classification"))
70+
})
71+
72+
test_that("wrong fit interface", {
73+
expect_error(
74+
proportional_hazards() %>% fit_xy(),
75+
"must use the formula interface"
76+
)
77+
})

0 commit comments

Comments
 (0)