Skip to content

Commit cf5c0a7

Browse files
willtebbutttheogfst--
authored
Use LIBSVM in SVM example (#341)
* Utilise LIBSVM * Lower-bound LIBSVM * Update examples/support-vector-machine/script.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Apply suggestions from code review Co-authored-by: st-- <[email protected]> Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: st-- <[email protected]>
1 parent b24c5b9 commit cf5c0a7

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

examples/support-vector-machine/Manifest.toml

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,26 @@ version = "2.1.0+0"
283283
deps = ["ChainRulesCore", "Compat", "CompositionsBase", "Distances", "FillArrays", "Functors", "LinearAlgebra", "Random", "Requires", "SpecialFunctions", "StatsBase", "StatsFuns", "TensorCore", "Test", "ZygoteRules"]
284284
path = "../.."
285285
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
286-
version = "0.10.6"
286+
version = "0.10.8"
287287

288288
[[LAME_jll]]
289289
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
290290
git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c"
291291
uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
292292
version = "3.100.1+0"
293293

294+
[[LIBLINEAR]]
295+
deps = ["Libdl", "SparseArrays", "liblinear_jll"]
296+
git-tree-sha1 = "81e40115c23acca9dfa30944050096b958271e5a"
297+
uuid = "2d691ee1-e668-5016-a719-b2531b85e0f5"
298+
version = "0.6.0"
299+
300+
[[LIBSVM]]
301+
deps = ["LIBLINEAR", "LinearAlgebra", "ScikitLearnBase", "SparseArrays", "libsvm_jll"]
302+
git-tree-sha1 = "729ea2db931587c983d0ef6691b62de5005c5570"
303+
uuid = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
304+
version = "0.7.0"
305+
294306
[[LZO_jll]]
295307
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
296308
git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6"
@@ -587,6 +599,12 @@ version = "0.3.0+0"
587599
[[SHA]]
588600
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
589601

602+
[[ScikitLearnBase]]
603+
deps = ["LinearAlgebra", "Random", "Statistics"]
604+
git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f"
605+
uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
606+
version = "0.5.0"
607+
590608
[[Scratch]]
591609
deps = ["Dates"]
592610
git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda"
@@ -882,12 +900,24 @@ git-tree-sha1 = "7a5780a0d9c6864184b3a2eeeb833a0c871f00ab"
882900
uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280"
883901
version = "0.1.6+4"
884902

903+
[[liblinear_jll]]
904+
deps = ["Libdl", "Pkg"]
905+
git-tree-sha1 = "6a4a6a3697269cb2da57e698e9318972d88de0bb"
906+
uuid = "275f1f90-abd2-5ca1-9ad8-abd4e3d66eb7"
907+
version = "2.30.0+0"
908+
885909
[[libpng_jll]]
886910
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
887911
git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c"
888912
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
889913
version = "1.6.38+0"
890914

915+
[[libsvm_jll]]
916+
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
917+
git-tree-sha1 = "ac78676ee5b1707de969d68d0a39db71f222925d"
918+
uuid = "08558c22-525a-5d2a-acf6-0ac6658ffce4"
919+
version = "3.24.0+1"
920+
891921
[[libvorbis_jll]]
892922
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"]
893923
git-tree-sha1 = "c45f4e40e7aafe9d086379e5578947ec8b95a8fb"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
4+
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
67
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
78

89
[compat]
910
Distributions = "0.25"
1011
KernelFunctions = "0.10"
12+
LIBSVM = "0.7"
1113
Literate = "2"
1214
Plots = "1"
1315
julia = "1.3"
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# # Support Vector Machine
22
#
3-
# !!! warning
4-
# This example is under construction
53

6-
using KernelFunctions
74
using Distributions
8-
using Plots
9-
5+
using KernelFunctions
6+
using LIBSVM
107
using LinearAlgebra
8+
using Plots
119
using Random
1210

1311
## Set plotting theme
@@ -20,23 +18,29 @@ Random.seed!(1234);
2018
N = 100;
2119

2220
# Select randomly between two classes:
23-
y = rand([-1, 1], N);
21+
y_train = rand([-1, 1], N);
2422

2523
# Random attributes for both classes:
2624
X = Matrix{Float64}(undef, 2, N)
27-
rand!(MvNormal(randn(2), I), view(X, :, y .== 1))
28-
rand!(MvNormal(randn(2), I), view(X, :, y .== -1));
25+
rand!(MvNormal(randn(2), I), view(X, :, y_train .== 1))
26+
rand!(MvNormal(randn(2), I), view(X, :, y_train .== -1));
27+
x_train = ColVecs(X);
2928

3029
# Create a 2D grid:
31-
xgrid = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
32-
Xgrid = ColVecs(mapreduce(collect, hcat, Iterators.product(xgrid, xgrid)));
30+
test_range = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
31+
x_test = ColVecs(mapreduce(collect, hcat, Iterators.product(test_range, test_range)));
3332

3433
# Create kernel function:
3534
k = SqExponentialKernel() ScaleTransform(2.0)
3635

37-
# Optimal prediction:
38-
f(x, X, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ) * I) * y
36+
# [LIBSVM](https://github.com/JuliaML/LIBSVM.jl) can make use of a pre-computed kernel matrix.
37+
# KernelFunctions.jl can be used to produce that.
38+
# Precomputed matrix for training (corresponds to linear kernel)
39+
model = svmtrain(kernelmatrix(k, x_train), y_train; kernel=LIBSVM.Kernel.Precomputed)
40+
41+
# Precomputed matrix for prediction
42+
y_pr, _ = svmpredict(model, kernelmatrix(k, x_train, x_test));
3943

4044
# Compute prediction on a grid:
41-
contourf(xgrid, xgrid, f(Xgrid, ColVecs(X), k, 0.1))
42-
scatter!(X[1, :], X[2, :]; color=y, lab="data", widen=false)
45+
contourf(test_range, test_range, y_pr)
46+
scatter!(X[1, :], X[2, :]; color=y_train, lab="data", widen=false)

0 commit comments

Comments
 (0)