Skip to content

Commit 33fc054

Browse files
committed
Improved documentation and tests
1 parent f90b50b commit 33fc054

File tree

10 files changed

+166
-10
lines changed

10 files changed

+166
-10
lines changed

docs/create_kernel_plots.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Plots; pyplot();
2+
using Distributions
3+
using LinearAlgebra
4+
using KernelFunctions
5+
# Translational invariants kernels
6+
7+
default(lw=3.0,titlefontsize=28,tickfontsize=18)
8+
9+
x₀ = 0.0; l=0.1
10+
n_grid = 101
11+
fill(x₀,n_grid,1)
12+
xrange = reshape(collect(range(-3,3,length=n_grid)),:,1)
13+
for k in [SqExponentialKernel,ExponentialKernel]
14+
K = kernelmatrix(k(),xrange,obsdim=1)
15+
v = rand(MvNormal(K+1e-7I))
16+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
17+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
18+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
19+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
20+
end
21+
22+
for k in [GammaExponentialKernel(1.0,1.5)]
23+
sparse =1
24+
while !isposdef(kernelmatrix(k,xrange*sparse,obsdim=1) + 1e-5I); sparse += 1; end
25+
v = rand(MvNormal(kernelmatrix(k,xrange*sparse,obsdim=1)+1e-7I))
26+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
27+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_GammaExponentialKernel.png"))
28+
plot(xrange,kernel.(k,x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
29+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_GammaExponentialKernel.png"))
30+
end
31+
32+
for k in [MaternKernel,Matern32Kernel,Matern52Kernel]
33+
K = kernelmatrix(k(),xrange,obsdim=1)
34+
v = rand(MvNormal(K+1e-7I))
35+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
36+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
37+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
38+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
39+
end
40+
41+
42+
for k in [RationalQuadraticKernel]
43+
K = kernelmatrix(k(),xrange,obsdim=1)
44+
v = rand(MvNormal(K+1e-7I))
45+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
46+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
47+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
48+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
49+
end

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ makedocs(
99
"User Guide" => "userguide.md",
1010
"Kernel Functions"=>"kernels.md",
1111
"Transform"=>"transform.md",
12+
"Metrics"=>"metrics.md",
13+
"Theory"=>"theory.md",
1214
"API"=>"api.md"]
1315
)
1416

docs/src/api.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# API Library
2+
3+
---
4+
```@contents
5+
Pages = ["api.md"]
6+
```
7+
8+
```@meta
9+
CurrentModule = KernelFunctions
10+
```
11+
12+
## Module
13+
```@docs
14+
KernelFunctions
15+
```
16+
17+
## Kernel Functions
18+
19+
```@docs
20+
SqExponentialKernel
21+
Exponential
22+
GammaExponentialKernel
23+
ExponentiatedKernel
24+
MaternKernel
25+
Matern32Kernel
26+
Matern52Kernel
27+
LinearKernel
28+
PolynomialKernel
29+
RationalQuadraticKernel
30+
GammaRationalQuadraticKernel
31+
```
32+
33+
## Kernel Combinations
34+
35+
```@docs
36+
KernelSum
37+
KernelProduct
38+
```
39+
40+
## Transforms
41+
42+
```@docs
43+
Transform
44+
IdentityTransform
45+
ScaleTransform
46+
LowRankTransform
47+
FunctionTransform
48+
```
49+
50+
## Functions
51+
52+
```@docs
53+
kernelmatrix
54+
kernelmatrix!
55+
kerneldiagmatrix
56+
kerneldiagmatrix!
57+
transform
58+
```
59+
60+
61+
## Index
62+
63+
```@index
64+
Pages = ["api.md"]
65+
Module = ["KernelFunctions"]
66+
Order = [:type, :function]
67+
```

docs/src/metrics.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Metrics
2+
3+
KernelFunctions.jl relies on [Distances.jl]() for computing the pairwise matrix.
4+
To do so a distance measure is needed for each kernel. Two very common ones can already be used : `SqEuclidean` and `Euclidean`.
5+
However all kernels do not rely on distances metrics respecting all the definitions. That's why two additional metrics come with the package : `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`). If you want to create a new distance just implement the following :
6+
7+
```julia
8+
struct Delta <: Distances.PreMetric
9+
end
10+
11+
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
12+
@boundscheck if length(a) != length(b)
13+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
14+
end
15+
return a==b
16+
end
17+
18+
@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
19+
@inline (dist::Delta)(a::Number,b::Number) = a==b
20+
```

docs/src/theory.md

Whitespace-only changes.

docs/src/transform.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
There is a more general `Transform`: `FunctionTransform` that uses a function and apply it on each vector via `mapslices`.
55
You can also create a pipeline of `Transform` via `TransformChain`. For example `LowRankTransform(rand(10,5))∘ScaleTransform(2.0)`.
66

7+
One apply a transformation on a matrix or a vector via `transform(t::Transform,v::AbstractVecOrMat)`
8+
79
## Transforms :
810

911
```@docs
1012
IdentityTransform
1113
ScaleTransform
1214
LowRankTransform
1315
FunctionTransform
14-
TransformChain
16+
ChainTransform
1517
```

docs/src/userguide.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1 @@
11
# Building kernel and matrices easily!
2-
3-
## Creating a kernel function
4-
5-
##

src/transform/transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export Transform, IdentityTransform, ScaleTransform, LowRankTransform, FunctionTransform, ChainTransform
2-
2+
export transform
33

44
abstract type Transform end
55

test/test_kernelmatrix.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ k = SqExponentialKernel()
4040
w1 = 0.4; w2 = 1.2;
4141
ks2 = KernelSum([k1,k2],weights=[w1,w2])
4242
@test all(kernelmatrix(ks,A) .== kernelmatrix(k1,A) + kernelmatrix(k2,A))
43+
@test all(kernelmatrix(ks+k1,A) .≈ 2*kernelmatrix(k1,A) + kernelmatrix(k2,A))
44+
@test all(kernelmatrix(k1+ks,A) .≈ 2*kernelmatrix(k1,A) + kernelmatrix(k2,A))
4345
@test all(kernelmatrix(ks,A,B) .== kernelmatrix(k1,A,B) + kernelmatrix(k2,A,B))
4446
@test all(kerneldiagmatrix(ks,A) .== kerneldiagmatrix(k1,A) + kerneldiagmatrix(k2,A))
4547
@test all(kernelmatrix(ks2,A) .== w1*kernelmatrix(k1,A) + w2*kernelmatrix(k2,A))
@@ -48,9 +50,11 @@ k = SqExponentialKernel()
4850
k1 = SqExponentialKernel()
4951
k2 = LinearKernel()
5052
kp = k1 * k2
51-
@test all(kernelmatrix(kp,A) .== kernelmatrix(k1,A) .* kernelmatrix(k2,A))
52-
@test all(kernelmatrix(kp,A,B) .== kernelmatrix(k1,A,B) .* kernelmatrix(k2,A,B))
53-
@test all(kerneldiagmatrix(kp,A) .== kerneldiagmatrix(k1,A) .* kerneldiagmatrix(k2,A))
54-
@test all(kernelmatrix(kp,A) .== kernelmatrix(k1,A) .* kernelmatrix(k2,A))
53+
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
54+
@test all(kernelmatrix(kp*k1,A) .≈ kernelmatrix(k1,A).^2 .* kernelmatrix(k2,A))
55+
@test all(kernelmatrix(k1*kp,A) .≈ kernelmatrix(k1,A).^2 .* kernelmatrix(k2,A))
56+
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
57+
@test all(kernelmatrix(kp,A,B) .≈ kernelmatrix(k1,A,B) .* kernelmatrix(k2,A,B))
58+
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
5559
end
5660
end

test/test_kernels.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,20 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
188188
@test kappa(GammaRationalQuadraticKernel(1.0,a,1.0),x) kappa(RationalQuadraticKernel(1.0,a),x)
189189
end
190190
end
191+
@testset "KernelCombinations" begin
192+
k1 = LinearKernel()
193+
k2 = SqExponentialKernel()
194+
X = rand(2,2)
195+
@testset "KernelSum" begin
196+
k = k1 + k2
197+
@test KernelFunctions.metric(k) == [KernelFunctions.DotProduct(),KernelFunctions.SqEuclidean()]
198+
@test length(k) == 2
199+
@test transform(k) == [transform(k1),transform(k2)]
200+
@test transform(k,X) == [transform(k1,X),transform(k2,X)]
201+
@test transform(k,X,1) == [transform(k1,X,1),transform(k2,X,1)]
202+
end
203+
@testset "KernelProduct" begin
204+
205+
end
206+
end
191207
end

0 commit comments

Comments
 (0)