Skip to content

Commit 65a5ab8

Browse files
authored
Add kernlab engine for svm_linear() (#438)
* Add kernlab as engine for svm_linear() * Tests for kernlab linear SVM * Redocument * Update NEWS for new engine
1 parent c8229a1 commit 65a5ab8

File tree

6 files changed

+436
-17
lines changed

6 files changed

+436
-17
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
44

5-
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429).
5+
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424) and the `kernlab` engine (#438), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429).
66

77
# parsnip 0.1.5
88

R/svm_linear.R

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#' The model can be created using the `fit()` function using the
3030
#' following _engines_:
3131
#' \itemize{
32-
#' \item \pkg{R}: `"LiblineaR"` (the default)
32+
#' \item \pkg{R}: `"LiblineaR"` (the default) or `"kernlab"`
3333
#' }
3434
#'
3535
#'
@@ -173,6 +173,16 @@ translate.svm_linear <- function(x, engine = x$engine, ...) {
173173
}
174174
}
175175

176+
if (x$engine == "kernlab") {
177+
178+
# unless otherwise specified, classification models predict probabilities
179+
if (x$mode == "classification" && !any(arg_names == "prob.model"))
180+
arg_vals$prob.model <- TRUE
181+
if (x$mode == "classification" && any(arg_names == "epsilon"))
182+
arg_vals$epsilon <- NULL
183+
184+
}
185+
176186
x$method$fit$args <- arg_vals
177187

178188
# worried about people using this to modify the specification
@@ -191,3 +201,7 @@ svm_linear_post <- function(results, object) {
191201
results$predictions
192202
}
193203

204+
svm_reg_linear_post <- function(results, object) {
205+
results[,1]
206+
}
207+

R/svm_linear_data.R

Lines changed: 157 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ set_fit(
3636
protect = c("x", "y", "wi"),
3737
data = c(x = "data", y = "target"),
3838
func = c(pkg = "LiblineaR", fun = "LiblineaR"),
39-
defaults = list(
40-
type = 11
41-
)
39+
defaults = list(type = 11)
4240
)
4341
)
4442

@@ -51,9 +49,7 @@ set_fit(
5149
data = c(x = "data", y = "target"),
5250
protect = c("x", "y", "wi"),
5351
func = c(pkg = "LiblineaR", fun = "LiblineaR"),
54-
defaults = list(
55-
type = 1
56-
)
52+
defaults = list(type = 1)
5753
)
5854
)
5955

@@ -162,3 +158,158 @@ set_pred(
162158
newx = quote(new_data))
163159
)
164160
)
161+
162+
# ------------------------------------------------------------------------------
163+
164+
set_model_engine("svm_linear", "classification", "kernlab")
165+
set_model_engine("svm_linear", "regression", "kernlab")
166+
set_dependency("svm_linear", "kernlab", "kernlab")
167+
168+
set_model_arg(
169+
model = "svm_linear",
170+
eng = "kernlab",
171+
parsnip = "cost",
172+
original = "C",
173+
func = list(pkg = "dials", fun = "cost", range = c(-10, 5)),
174+
has_submodel = FALSE
175+
)
176+
177+
set_model_arg(
178+
model = "svm_linear",
179+
eng = "kernlab",
180+
parsnip = "margin",
181+
original = "epsilon",
182+
func = list(pkg = "dials", fun = "svm_margin"),
183+
has_submodel = FALSE
184+
)
185+
186+
set_fit(
187+
model = "svm_linear",
188+
eng = "kernlab",
189+
mode = "regression",
190+
value = list(
191+
interface = "formula",
192+
data = c(formula = "x", data = "data"),
193+
protect = c("x", "data"),
194+
func = c(pkg = "kernlab", fun = "ksvm"),
195+
defaults = list(kernel = "vanilladot")
196+
)
197+
)
198+
199+
set_fit(
200+
model = "svm_linear",
201+
eng = "kernlab",
202+
mode = "classification",
203+
value = list(
204+
interface = "formula",
205+
data = c(formula = "x", data = "data"),
206+
protect = c("x", "data"),
207+
func = c(pkg = "kernlab", fun = "ksvm"),
208+
defaults = list(kernel = "vanilladot")
209+
)
210+
)
211+
212+
set_encoding(
213+
model = "svm_linear",
214+
eng = "kernlab",
215+
mode = "regression",
216+
options = list(
217+
predictor_indicators = "none",
218+
compute_intercept = FALSE,
219+
remove_intercept = FALSE,
220+
allow_sparse_x = FALSE
221+
)
222+
)
223+
224+
set_pred(
225+
model = "svm_linear",
226+
eng = "kernlab",
227+
mode = "regression",
228+
type = "numeric",
229+
value = list(
230+
pre = NULL,
231+
post = svm_reg_linear_post,
232+
func = c(pkg = "kernlab", fun = "predict"),
233+
args =
234+
list(
235+
object = quote(object$fit),
236+
newdata = quote(new_data),
237+
type = "response"
238+
)
239+
)
240+
)
241+
242+
set_pred(
243+
model = "svm_linear",
244+
eng = "kernlab",
245+
mode = "regression",
246+
type = "raw",
247+
value = list(
248+
pre = NULL,
249+
post = NULL,
250+
func = c(pkg = "kernlab", fun = "predict"),
251+
args = list(object = quote(object$fit), newdata = quote(new_data))
252+
)
253+
)
254+
255+
set_encoding(
256+
model = "svm_linear",
257+
eng = "kernlab",
258+
mode = "classification",
259+
options = list(
260+
predictor_indicators = "none",
261+
compute_intercept = FALSE,
262+
remove_intercept = FALSE,
263+
allow_sparse_x = FALSE
264+
)
265+
)
266+
267+
set_pred(
268+
model = "svm_linear",
269+
eng = "kernlab",
270+
mode = "classification",
271+
type = "class",
272+
value = list(
273+
pre = NULL,
274+
post = NULL,
275+
func = c(pkg = "kernlab", fun = "predict"),
276+
args =
277+
list(
278+
object = quote(object$fit),
279+
newdata = quote(new_data),
280+
type = "response"
281+
)
282+
)
283+
)
284+
285+
set_pred(
286+
model = "svm_linear",
287+
eng = "kernlab",
288+
mode = "classification",
289+
type = "prob",
290+
value = list(
291+
pre = NULL,
292+
post = function(result, object) as_tibble(result),
293+
func = c(pkg = "kernlab", fun = "predict"),
294+
args =
295+
list(
296+
object = quote(object$fit),
297+
newdata = quote(new_data),
298+
type = "probabilities"
299+
)
300+
)
301+
)
302+
303+
set_pred(
304+
model = "svm_linear",
305+
eng = "kernlab",
306+
mode = "classification",
307+
type = "raw",
308+
value = list(
309+
pre = NULL,
310+
post = NULL,
311+
func = c(pkg = "kernlab", fun = "predict"),
312+
args = list(object = quote(object$fit), newdata = quote(new_data))
313+
)
314+
)
315+

man/rmd/svm-linear.Rmd

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,25 @@ predictions (e.g., accuracy and so on).
3131
This engine fits models that are L2-regularized for L2-loss. In the `LiblineaR`
3232
documentation, these are types 1 (classification) and 11 (regression).
3333

34+
## kernlab
35+
36+
```{r kernlab-reg}
37+
svm_linear() %>%
38+
set_engine("kernlab") %>%
39+
set_mode("regression") %>%
40+
translate()
41+
```
42+
43+
```{r kernlab-cls}
44+
svm_linear() %>%
45+
set_engine("kernlab") %>%
46+
set_mode("classification") %>%
47+
translate()
48+
```
49+
50+
`fit()` passes the data directly to `kernlab::ksvm()` so that its formula method can create dummy variables as-needed.
51+
52+
3453
## Parameter translations
3554

3655
The standardized parameter names in parsnip can be mapped to their original
@@ -44,6 +63,8 @@ get_defaults_svm_linear <- function() {
4463
~model, ~engine, ~parsnip, ~original, ~default,
4564
"svm_linear", "LiblineaR", "cost", "C", "1",
4665
"svm_linear", "LiblineaR", "margin", "svr_eps", "0.1",
66+
"svm_linear", "kernlab", "cost", "C", "1",
67+
"svm_linear", "kernlab", "margin", "epsilon", "0.1",
4768
)
4869
}
4970
convert_args("svm_linear")

man/svm_linear.Rd

Lines changed: 32 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)