Skip to content

Commit b69e8e1

Browse files
committed
improve writing
1 parent a64d1eb commit b69e8e1

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

examples/support-vector-machine/script.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# # Support Vector Machine
22
#
3+
# In this notebook we show how you can use KernelFunctions.jl to generate
4+
# kernel matrices for classification with a support vector machine, as
5+
# implemented by LIBSVM.
36

47
using Distributions
58
using KernelFunctions
@@ -16,21 +19,23 @@ Random.seed!(1234);
1619
# Number of samples per class:
1720
nin = nout = 50;
1821

19-
# Generate data
20-
## based on SciKit-Learn's sklearn.datasets.make_moons function
22+
# We generate data based on SciKit-Learn's sklearn.datasets.make_moons function:
23+
2124
class1x = cos.(range(0, π; length=nout))
2225
class1y = sin.(range(0, π; length=nout))
2326
class2x = 1 .- cos.(range(0, π; length=nin))
2427
class2y = 1 .- sin.(range(0, π; length=nin)) .- 0.5
2528
X = hcat(vcat(class1x, class2x), vcat(class1y, class2y))
2629
X .+= 0.1randn(size(X))
27-
x_train = RowVecs(X);
30+
x_train = RowVecs(X)
2831
y_train = vcat(fill(-1, nout), fill(1, nin));
2932

3033
# Create a 100×100 2D grid for evaluation:
3134
test_range = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
3235
x_test = ColVecs(mapreduce(collect, hcat, Iterators.product(test_range, test_range)));
3336

37+
# ## SVM model
38+
#
3439
# Create kernel function:
3540
k = SqExponentialKernel() ∘ ScaleTransform(1.5)
3641

@@ -43,12 +48,8 @@ model = svmtrain(kernelmatrix(k, x_train), y_train; kernel=LIBSVM.Kernel.Precomp
4348
# Precomputed matrix for prediction
4449
y_pred, _ = svmpredict(model, kernelmatrix(k, x_train, x_test));
4550

46-
# Compute prediction on a grid:
47-
plot(; lim=extrema(test_range))
51+
# Visualize prediction on a grid:
52+
plot(; lim=extrema(test_range), aspect_ratio=1)
4853
contourf!(test_range, test_range, y_pred; levels=1, color=cgrad(:redsblues), alpha=0.7)
49-
scatter!(
50-
X[y_train .== -1, 1], X[y_train .== -1, 2]; color=:red, label="class 1", widen=false
51-
)
52-
scatter!(
53-
X[y_train .== +1, 1], X[y_train .== +1, 2]; color=:blue, label="class 2", widen=false
54-
)
54+
scatter!(X[y_train .== -1, 1], X[y_train .== -1, 2]; color=:red, label="class 1")
55+
scatter!(X[y_train .== +1, 1], X[y_train .== +1, 2]; color=:blue, label="class 2")

0 commit comments

Comments
 (0)