Skip to content

Commit 9cc0dc7

Browse files
committed
2 parents 1d2516e + 687b251 commit 9cc0dc7

File tree

8 files changed

+804
-1
lines changed

8 files changed

+804
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,5 @@ Suggests:
5252
MASS,
5353
nlme,
5454
modeldata,
55+
LiblineaR,
5556
Matrix

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ S3method(print,nearest_neighbor)
5353
S3method(print,nullmodel)
5454
S3method(print,rand_forest)
5555
S3method(print,surv_reg)
56+
S3method(print,svm_linear)
5657
S3method(print,svm_poly)
5758
S3method(print,svm_rbf)
5859
S3method(req_pkgs,model_fit)
@@ -74,6 +75,7 @@ S3method(translate,multinom_reg)
7475
S3method(translate,nearest_neighbor)
7576
S3method(translate,rand_forest)
7677
S3method(translate,surv_reg)
78+
S3method(translate,svm_linear)
7779
S3method(translate,svm_poly)
7880
S3method(translate,svm_rbf)
7981
S3method(type_sum,model_fit)
@@ -88,6 +90,7 @@ S3method(update,multinom_reg)
8890
S3method(update,nearest_neighbor)
8991
S3method(update,rand_forest)
9092
S3method(update,surv_reg)
93+
S3method(update,svm_linear)
9194
S3method(update,svm_poly)
9295
S3method(update,svm_rbf)
9396
S3method(varying_args,model_spec)
@@ -182,6 +185,7 @@ export(show_fit)
182185
export(show_model_info)
183186
export(stan_conf_int)
184187
export(surv_reg)
188+
export(svm_linear)
185189
export(svm_poly)
186190
export(svm_rbf)
187191
export(tidy)

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# parsnip (development version)
22

3-
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN.
3+
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
4+
5+
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine. (#424)
46

57
# parsnip 0.1.5
68

R/svm_linear.R

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#' General interface for linear support vector machines
2+
#'
3+
#' `svm_linear()` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using
5+
#' different packages in R or via Spark. The main arguments for the
6+
#' model are:
7+
#' \itemize{
8+
#' \item \code{cost}: The cost of predicting a sample within or on the
9+
#' wrong side of the margin.
10+
#' \item \code{margin}: The epsilon in the SVM insensitive loss function
11+
#' (regression only)
12+
#' }
13+
#' These arguments are converted to their specific names at the
14+
#' time that the model is fit. Other options and arguments can be
15+
#' set using `set_engine()`. If left to their defaults
16+
#' here (`NULL`), the values are taken from the underlying model
17+
#' functions. If parameters need to be modified, `update()` can be used
18+
#' in lieu of recreating the object from scratch.
19+
#'
20+
#' @inheritParams boost_tree
21+
#' @param mode A single character string for the type of model.
22+
#' Possible values for this model are "unknown", "regression", or
23+
#' "classification".
24+
#' @param cost A positive number for the cost of predicting a sample within
25+
#' or on the wrong side of the margin
26+
#' @param margin A positive number for the epsilon in the SVM insensitive
27+
#' loss function (regression only)
28+
#' @details
29+
#' The model can be created using the `fit()` function using the
30+
#' following _engines_:
31+
#' \itemize{
32+
#' \item \pkg{R}: `"LiblineaR"` (the default)
33+
#' }
34+
#'
35+
#'
36+
#' @includeRmd man/rmd/svm-linear.Rmd details
37+
#'
38+
#' @importFrom purrr map_lgl
39+
#' @seealso [fit()]
40+
#' @examples
41+
#' show_engines("svm_linear")
42+
#'
43+
#' svm_linear(mode = "classification")
44+
#' # Parameters can be represented by a placeholder:
45+
#' svm_linear(mode = "regression", cost = varying())
46+
#' @export
47+
48+
svm_linear <-
49+
function(mode = "unknown",
50+
cost = NULL, margin = NULL) {
51+
52+
args <- list(
53+
cost = enquo(cost),
54+
margin = enquo(margin)
55+
)
56+
57+
new_model_spec(
58+
"svm_linear",
59+
args = args,
60+
eng_args = NULL,
61+
mode = mode,
62+
method = NULL,
63+
engine = NULL
64+
)
65+
}
66+
67+
#' @export
68+
print.svm_linear <- function(x, ...) {
69+
cat("Linear Support Vector Machine Specification (", x$mode, ")\n\n", sep = "")
70+
model_printer(x, ...)
71+
72+
if(!is.null(x$method$fit$args)) {
73+
cat("Model fit template:\n")
74+
print(show_call(x))
75+
}
76+
invisible(x)
77+
}
78+
79+
# ------------------------------------------------------------------------------
80+
81+
#' @export
82+
#' @inheritParams update.boost_tree
83+
#' @param object A linear SVM model specification.
84+
#' @examples
85+
#' model <- svm_linear(cost = 3)
86+
#' model
87+
#' update(model, cost = 1)
88+
#' update(model, cost = 1, fresh = TRUE)
89+
#' @method update svm_linear
90+
#' @rdname svm_linear
91+
#' @export
92+
update.svm_linear <-
93+
function(object,
94+
parameters = NULL,
95+
cost = NULL, margin = NULL,
96+
fresh = FALSE,
97+
...) {
98+
99+
eng_args <- update_engine_parameters(object$eng_args, ...)
100+
101+
if (!is.null(parameters)) {
102+
parameters <- check_final_param(parameters)
103+
}
104+
105+
args <- list(
106+
cost = enquo(cost),
107+
margin = enquo(margin)
108+
)
109+
110+
args <- update_main_parameters(args, parameters)
111+
112+
if (fresh) {
113+
object$args <- args
114+
object$eng_args <- eng_args
115+
} else {
116+
null_args <- map_lgl(args, null_value)
117+
if (any(null_args))
118+
args <- args[!null_args]
119+
if (length(args) > 0)
120+
object$args[names(args)] <- args
121+
if (length(eng_args) > 0)
122+
object$eng_args[names(eng_args)] <- eng_args
123+
}
124+
125+
new_model_spec(
126+
"svm_linear",
127+
args = object$args,
128+
eng_args = object$eng_args,
129+
mode = object$mode,
130+
method = NULL,
131+
engine = object$engine
132+
)
133+
}
134+
135+
# ------------------------------------------------------------------------------
136+
137+
#' @export
138+
translate.svm_linear <- function(x, engine = x$engine, ...) {
139+
x <- translate.default(x, engine = engine, ...)
140+
141+
# slightly cleaner code using
142+
arg_vals <- x$method$fit$args
143+
arg_names <- names(arg_vals)
144+
145+
# add checks to error trap or change things for this method
146+
147+
if (x$engine == "LiblineaR") {
148+
149+
if (is_null(x$eng_args$type)) {
150+
liblinear_type <- NULL
151+
} else {
152+
liblinear_type <- quo_get_expr(x$eng_args$type)
153+
}
154+
155+
if (x$mode == "regression") {
156+
if (is_null(quo_get_expr(x$args$margin)))
157+
arg_vals$svr_eps <- 0.1
158+
if (!is_null(liblinear_type))
159+
if(!liblinear_type %in% 11:13)
160+
rlang::abort(
161+
paste0("The LiblineaR engine argument of `type` = ",
162+
liblinear_type,
163+
" does not correspond to an SVM regression model.")
164+
)
165+
} else if (x$mode == "classification") {
166+
if (!is_null(liblinear_type))
167+
if(!liblinear_type %in% 1:5)
168+
rlang::abort(
169+
paste0("The LiblineaR engine argument of `type` = ",
170+
liblinear_type,
171+
" does not correspond to an SVM classification model.")
172+
)
173+
}
174+
}
175+
176+
x$method$fit$args <- arg_vals
177+
178+
# worried about people using this to modify the specification
179+
x
180+
}
181+
182+
# ------------------------------------------------------------------------------
183+
184+
check_args.svm_linear <- function(object) {
185+
invisible(object)
186+
}
187+
188+
# ------------------------------------------------------------------------------
189+
190+
svm_linear_post <- function(results, object) {
191+
results$predictions
192+
}
193+

0 commit comments

Comments
 (0)