Skip to content

Commit d095935

Browse files
committed
Added plots for the README
1 parent 5808f9c commit d095935

File tree

10 files changed

+58
-8
lines changed

10 files changed

+58
-8
lines changed

README.md

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,35 @@ KernelFunctions.jl provide a flexible and complete framework for kernel function
88

99
The aim is to make the API as model-agnostic as possible while still being user-friendly.
1010

11+
## Examples
12+
13+
```julia
14+
X = reshape(collect(range(-3.0,3.0,length=100)),:,1)
15+
# Set simple scaling of the data
16+
k₁ = SqExponentialKernel(1.0)
17+
K₁ = kernelmatrix(k,X,obsdim=1)
18+
19+
# Set a function transformation on the data
20+
k₂ = MaternKernel(FunctionTransform(x->sin.(x)))
21+
K₂ = kernelmatrix(k,X,obsdim=1)
22+
23+
# Set a matrix premultiplication on the data
24+
k₃ = PolynomialKernel(LowRankTransform(randn(4,1)),0.0,2.0)
25+
K₃ = kernelmatrix(k,X,obsdim=1)
26+
27+
# Add and sum kernels
28+
k₄ = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*k₂
29+
K₄ = kernelmatrix(k,X,obsdim=1)
30+
31+
heatmap([K₁,K₂,K₃,K₄],yflip=false,colorbar=false)
32+
```
33+
<p align=center>
34+
<img src="docs/src/assets/heatmap_combination.png" width=400px>
35+
</p>
36+
1137
## Objectives (by priority)
12-
- ARD Kernels
13-
- AD Compatible (Zygote, ForwardDiff, ReverseDiff)
14-
- Kernel sum and product
38+
- AD Compatibility (Zygote, ForwardDiff)
1539
- Toeplitz Matrices
1640
- BLAS backend
1741

18-
19-
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package
42+
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package.

docs/.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
build/
22
site/
3-
src/assets/*.png
43

54
#Temp to avoid to many changes

docs/create_kernel_plots.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@ x₀ = 0.0; l=0.1
1010
n_grid = 101
1111
fill(x₀,n_grid,1)
1212
xrange = reshape(collect(range(-3,3,length=n_grid)),:,1)
13+
14+
k = SqExponentialKernel(1.0)
15+
K1 = kernelmatrix(k,xrange,obsdim=1)
16+
p = heatmap(K1,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
17+
savefig(joinpath(@__DIR__,"src","assets","heatmap_sqexp.png"))
18+
19+
20+
k = Matern32Kernel(FunctionTransform(x->(sin.(x)).^2))
21+
K2 = kernelmatrix(k,xrange,obsdim=1)
22+
p = heatmap(K2,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
23+
savefig(joinpath(@__DIR__,"src","assets","heatmap_matern.png"))
24+
25+
26+
k = PolynomialKernel(LowRankTransform(randn(3,1)),2.0,0.0)
27+
K3 = kernelmatrix(k,xrange,obsdim=1)
28+
p = heatmap(K3,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
29+
savefig(joinpath(@__DIR__,"src","assets","heatmap_poly.png"))
30+
31+
k = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*Matern32Kernel(FunctionTransform(x->sin.(x)))
32+
K4 = kernelmatrix(k,xrange,obsdim=1)
33+
p = heatmap(K4,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
34+
savefig(joinpath(@__DIR__,"src","assets","heatmap_prodsum.png"))
35+
36+
plot(heatmap.([K1,K2,K3,K4],yflip=true,colorbar=false)...,layout=(2,2))
37+
savefig(joinpath(@__DIR__,"src","assets","heatmap_combination.png"))
38+
39+
1340
for k in [SqExponentialKernel,ExponentialKernel]
1441
K = kernelmatrix(k(),xrange,obsdim=1)
1542
v = rand(MvNormal(K+1e-7I))
67.8 KB
Loading

docs/src/assets/heatmap_matern.png

37.8 KB
Loading

docs/src/assets/heatmap_poly.png

17.8 KB
Loading

docs/src/assets/heatmap_prodsum.png

21.5 KB
Loading

docs/src/assets/heatmap_sqexp.png

7.43 KB
Loading

src/kernels/kernelsum.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
2525
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))
2626
Base.:+(k::Kernel,ks::KernelSum) = KernelSum(vcat(k,ks.kernels),weights=vcat(1.0,ks.weights))
2727
Base.:+(ks::KernelSum,k::Kernel) = KernelSum(vcat(ks.kernels,k),weights=vcat(ks.weights,1.0))
28-
Base.:*(w::Real,k::Kernel) = KernelSum([k],[w]) #TODO add tests
28+
Base.:*(w::Real,k::Kernel) = KernelSum([k],weights=[w]) #TODO add tests
29+
Base.:*(w::Real,k::KernelSum) = KernelSum(k.kernels,weights=w*k.weights) #TODO add tests
2930

3031

3132
Base.length(k::KernelSum) = length(k.kernels)

src/matrix/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function kernelmatrix(
8787
obsdim=defaultobs
8888
)
8989
if !check_dims(X,Y,feature_dim(obsdim),obsdim)
90-
throw(DimensionMismatch("X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"))
90+
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
9191
end
9292
_kernelmatrix(κ,X,Y,obsdim)
9393
end

0 commit comments

Comments
 (0)