Skip to content

Commit e962ac8

Browse files
committed
add proportional_hazards model spec (formerly censored::cox_reg())
1 parent fe45dab commit e962ac8

File tree

4 files changed

+202
-0
lines changed

4 files changed

+202
-0
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ S3method(print,model_spec)
5555
S3method(print,multinom_reg)
5656
S3method(print,nearest_neighbor)
5757
S3method(print,nullmodel)
58+
S3method(print,proportional_hazards)
5859
S3method(print,rand_forest)
5960
S3method(print,surv_reg)
6061
S3method(print,survival_reg)
@@ -95,6 +96,7 @@ S3method(update,mars)
9596
S3method(update,mlp)
9697
S3method(update,multinom_reg)
9798
S3method(update,nearest_neighbor)
99+
S3method(update,proportional_hazards)
98100
S3method(update,rand_forest)
99101
S3method(update,surv_reg)
100102
S3method(update,survival_reg)
@@ -176,6 +178,7 @@ export(predict_survival.model_fit)
176178
export(predict_time)
177179
export(predict_time.model_fit)
178180
export(prepare_data)
181+
export(proportional_hazards)
179182
export(rand_forest)
180183
export(repair_call)
181184
export(req_pkgs)

R/proportional_hazards.R

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#' General Interface for Proportional Hazards Models
2+
#'
3+
#' `proportional_hazards()` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using different packages
5+
#' in R. The main arguments for the model are:
6+
#' \itemize{
7+
#' \item \code{penalty}: The total amount of regularization
8+
#' in the model. Note that this must be zero for some engines.
9+
#' \item \code{mixture}: The mixture amounts of different types of
10+
#' regularization (see below). Note that this will be ignored for some engines.
11+
#' }
12+
#' These arguments are converted to their specific names at the
13+
#' time that the model is fit. Other options and arguments can be
14+
#' set using `set_engine()`. If left to their defaults
15+
#' here (`NULL`), the values are taken from the underlying model
16+
#' functions. If parameters need to be modified, `update()` can be used
17+
#' in lieu of recreating the object from scratch.
18+
#'
19+
#' @param mode A single character string for the type of model.
20+
#' Possible values for this model are "unknown", or "censored regression".
21+
#' @inheritParams linear_reg
22+
#'
23+
#' @details
24+
#' Proportional hazards models include the Cox model.
25+
#' For `proportional_hazards()`, the mode will always be "censored regression".
26+
#'
27+
#' @examples
28+
#' show_engines("proportional_hazards")
29+
#'
30+
#' @export
31+
proportional_hazards <- function(mode = "censored regression",
32+
penalty = NULL,
33+
mixture = NULL) {
34+
35+
args <- list(
36+
penalty = enquo(penalty),
37+
mixture = enquo(mixture)
38+
)
39+
40+
new_model_spec(
41+
"proportional_hazards",
42+
args = args,
43+
eng_args = NULL,
44+
mode = mode,
45+
method = NULL,
46+
engine = NULL
47+
)
48+
}
49+
50+
#' @export
51+
print.proportional_hazards <- function(x, ...) {
52+
cat("Proportional Hazards Model Specification (", x$mode, ")\n\n", sep = "")
53+
model_printer(x, ...)
54+
55+
if (!is.null(x$method$fit$args)) {
56+
cat("Model fit template:\n")
57+
print(show_call(x))
58+
}
59+
60+
invisible(x)
61+
}
62+
63+
# ------------------------------------------------------------------------------
64+
65+
#' @param object A proportional hazards model specification.
66+
#' @param ... Not used for `update()`.
67+
#' @param fresh A logical for whether the arguments should be
68+
#' modified in-place of or replaced wholesale.
69+
#' @examples
70+
#' model <- proportional_hazards(penalty = 10, mixture = 0.1)
71+
#' model
72+
#' update(model, penalty = 1)
73+
#' update(model, penalty = 1, fresh = TRUE)
74+
#' @method update proportional_hazards
75+
#' @rdname proportional_hazards
76+
#' @export
77+
update.proportional_hazards <- function(object,
78+
parameters = NULL,
79+
penalty = NULL,
80+
mixture = NULL,
81+
fresh = FALSE, ...) {
82+
83+
eng_args <- update_engine_parameters(object$eng_args, ...)
84+
85+
if (!is.null(parameters)) {
86+
parameters <- check_final_param(parameters)
87+
}
88+
args <- list(
89+
penalty = enquo(penalty),
90+
mixture = enquo(mixture)
91+
)
92+
93+
args <- update_main_parameters(args, parameters)
94+
95+
if (fresh) {
96+
object$args <- args
97+
object$eng_args <- eng_args
98+
} else {
99+
null_args <- map_lgl(args, null_value)
100+
if (any(null_args))
101+
args <- args[!null_args]
102+
if (length(args) > 0)
103+
object$args[names(args)] <- args
104+
if (length(eng_args) > 0)
105+
object$eng_args[names(eng_args)] <- eng_args
106+
}
107+
108+
new_model_spec(
109+
"proportional_hazards",
110+
args = object$args,
111+
eng_args = object$eng_args,
112+
mode = object$mode,
113+
method = NULL,
114+
engine = object$engine
115+
)
116+
}

R/proportional_hazards_data.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
# parnip just contains the model specification, the engines are the censored package.
3+
4+
set_new_model("proportional_hazards")
5+
set_model_mode("proportional_hazards", "censored regression")

man/proportional_hazards.Rd

Lines changed: 78 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)