Skip to content

Commit 6fc3a6f

Browse files
committed
Merge branch 'master-dev'
2 parents c9411ee + 33fc054 commit 6fc3a6f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1201
-291
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@ version = "0.2.0"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
7+
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
910
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1011
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
12+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1113

1214
[compat]
1315
FiniteDifferences = ">= 0.7.2"
1416

1517
[extras]
1618
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
19+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1720
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1821
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1923
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2024

2125
[targets]
22-
test = ["FiniteDifferences", "Random", "Test"]
26+
test = ["FiniteDifferences", "Tracker", "ForwardDiff", "Random", "Test"]

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
[![Build Status](https://travis-ci.org/theogf/KernelFunctions.jl.svg?branch=master)](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl)
2+
[![Coverage Status](https://coveralls.io/repos/github/theogf/KernelFunctions.jl/badge.svg?branch=master)](https://coveralls.io/github/theogf/KernelFunctions.jl?branch=master)
23
[![Documentation](https://img.shields.io/badge/docs-dev-blue.svg)](https://theogf.github.io/KernelFunctions.jl/dev/)
3-
# KernelFunctions.jl (WIP)
4-
Julia Package for kernel functions for machine learning
4+
# KernelFunctions.jl
5+
## Kernel functions for machine learning
6+
7+
KernelFunctions.jl provide a flexible and complete framework for kernel functions, pretransforming the input data.
8+
9+
The aim is to make the API as model-agnostic as possible while still being user-friendly.
510

611
## Objectives (by priority)
712
- ARD Kernels

dev/debugAD.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
99
l = 0.1
1010
vl = l*ones(dims[1])
1111
testfunction(k,A,B) = det(kernelmatrix(k,A,B))
12-
testfunction(k,A) = sum(kernelmatrix(k,A))
12+
testfunction(k,A) = sum(kernelmatrix(k,A,obsdim=2))
1313
k = MaternKernel(vl)
1414
KernelFunctions.kappa(k,3)
1515
testfunction(SqExponentialKernel(vl),A)
1616
testfunction(MaternKernel(vl),A)
17-
@which kernelmatrix(MaternKernel(vl),A,B)
17+
kernelmatrix(MaternKernel(vl),A)
1818
#For debugging
1919
@info "Running Zygote gradients"
2020
Zygote.refresh()
@@ -40,10 +40,10 @@ Zygote.gradient(x->kernelmatrix(MaternKernel(x,1.0),A)[1],l)
4040

4141
@info "Running ForwardDiff gradients"
4242
## ForwardDiff
43-
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl) #
44-
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
4543
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A),vl) #
4644
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A),vl) #
45+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl) #
46+
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
4747
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A,B),[l])
4848
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A,B),[l])
4949
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A),[l])

docs/create_kernel_plots.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Plots; pyplot();
2+
using Distributions
3+
using LinearAlgebra
4+
using KernelFunctions
5+
# Translational invariants kernels
6+
7+
default(lw=3.0,titlefontsize=28,tickfontsize=18)
8+
9+
x₀ = 0.0; l=0.1
10+
n_grid = 101
11+
fill(x₀,n_grid,1)
12+
xrange = reshape(collect(range(-3,3,length=n_grid)),:,1)
13+
for k in [SqExponentialKernel,ExponentialKernel]
14+
K = kernelmatrix(k(),xrange,obsdim=1)
15+
v = rand(MvNormal(K+1e-7I))
16+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
17+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
18+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
19+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
20+
end
21+
22+
for k in [GammaExponentialKernel(1.0,1.5)]
23+
sparse =1
24+
while !isposdef(kernelmatrix(k,xrange*sparse,obsdim=1) + 1e-5I); sparse += 1; end
25+
v = rand(MvNormal(kernelmatrix(k,xrange*sparse,obsdim=1)+1e-7I))
26+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
27+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_GammaExponentialKernel.png"))
28+
plot(xrange,kernel.(k,x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
29+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_GammaExponentialKernel.png"))
30+
end
31+
32+
for k in [MaternKernel,Matern32Kernel,Matern52Kernel]
33+
K = kernelmatrix(k(),xrange,obsdim=1)
34+
v = rand(MvNormal(K+1e-7I))
35+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
36+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
37+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
38+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
39+
end
40+
41+
42+
for k in [RationalQuadraticKernel]
43+
K = kernelmatrix(k(),xrange,obsdim=1)
44+
v = rand(MvNormal(K+1e-7I))
45+
plot(xrange,v,lab="",title="f(x)",framestyle=:none) |> display
46+
savefig(joinpath(@__DIR__,"src","assets","GP_sample_$(k).png"))
47+
plot(xrange,kernel.(k(),x₀,xrange),lab="",ylims=(0,1.1),title="k(0,x)") |> display
48+
savefig(joinpath(@__DIR__,"src","assets","kappa_function_$(k).png"))
49+
end

docs/make.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@ using KernelFunctions
44
makedocs(
55
sitename = "KernelFunctions",
66
format = Documenter.HTML(),
7-
modules = [KernelFunctions]
7+
modules = [KernelFunctions],
8+
pages = ["Home"=>"index.md",
9+
"User Guide" => "userguide.md",
10+
"Kernel Functions"=>"kernels.md",
11+
"Transform"=>"transform.md",
12+
"Metrics"=>"metrics.md",
13+
"Theory"=>"theory.md",
14+
"API"=>"api.md"]
815
)
916

1017
# Documenter can also automatically deploy documentation to gh-pages.

docs/src/api.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# API Library
2+
3+
---
4+
```@contents
5+
Pages = ["api.md"]
6+
```
7+
8+
```@meta
9+
CurrentModule = KernelFunctions
10+
```
11+
12+
## Module
13+
```@docs
14+
KernelFunctions
15+
```
16+
17+
## Kernel Functions
18+
19+
```@docs
20+
SqExponentialKernel
21+
Exponential
22+
GammaExponentialKernel
23+
ExponentiatedKernel
24+
MaternKernel
25+
Matern32Kernel
26+
Matern52Kernel
27+
LinearKernel
28+
PolynomialKernel
29+
RationalQuadraticKernel
30+
GammaRationalQuadraticKernel
31+
```
32+
33+
## Kernel Combinations
34+
35+
```@docs
36+
KernelSum
37+
KernelProduct
38+
```
39+
40+
## Transforms
41+
42+
```@docs
43+
Transform
44+
IdentityTransform
45+
ScaleTransform
46+
LowRankTransform
47+
FunctionTransform
48+
```
49+
50+
## Functions
51+
52+
```@docs
53+
kernelmatrix
54+
kernelmatrix!
55+
kerneldiagmatrix
56+
kerneldiagmatrix!
57+
transform
58+
```
59+
60+
61+
## Index
62+
63+
```@index
64+
Pages = ["api.md"]
65+
Module = ["KernelFunctions"]
66+
Order = [:type, :function]
67+
```

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# KernelFunctions.jl
22

3-
Documentation for KernelFunctions.jl
3+
Model agnostic kernel functions compatible with automatic differentiation
44

55
*** In Construction ***

docs/src/kernels.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
```@meta
2+
CurrentModule = KernelFunctions
3+
```
4+
5+
## Exponential Kernels
6+
7+
```@docs
8+
ExponentialKernel
9+
SqExponentialKernel
10+
GammaExponentialKernel
11+
```
12+
13+
## Matern Kernels
14+
15+
```@docs
16+
MaternKernel
17+
Matern32Kernel
18+
Matern52Kernel
19+
```
20+
21+
## Polynomial Kernels
22+
23+
```@docs
24+
LinearKernel
25+
PolynomialKernel
26+
```
27+
28+
## Constant Kernels
29+
30+
```@docs
31+
ConstantKernel
32+
WhiteKernel
33+
ZeroKernel
34+
```

docs/src/metrics.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Metrics
2+
3+
KernelFunctions.jl relies on [Distances.jl]() for computing the pairwise matrix.
4+
To do so a distance measure is needed for each kernel. Two very common ones can already be used : `SqEuclidean` and `Euclidean`.
5+
However all kernels do not rely on distances metrics respecting all the definitions. That's why two additional metrics come with the package : `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`). If you want to create a new distance just implement the following :
6+
7+
```julia
8+
struct Delta <: Distances.PreMetric
9+
end
10+
11+
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
12+
@boundscheck if length(a) != length(b)
13+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
14+
end
15+
return a==b
16+
end
17+
18+
@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
19+
@inline (dist::Delta)(a::Number,b::Number) = a==b
20+
```

docs/src/theory.md

Whitespace-only changes.

docs/src/transform.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Transform
2+
3+
`Transform` is the object that takes care of transforming the input data before distances are being computed. It can be as standard as `IdentityTransform` returning the same input, can be a scalar with `ScaleTransform` multiplying the vectors by a scalar or a vector.
4+
There is a more general `Transform`: `FunctionTransform` that uses a function and apply it on each vector via `mapslices`.
5+
You can also create a pipeline of `Transform` via `TransformChain`. For example `LowRankTransform(rand(10,5))∘ScaleTransform(2.0)`.
6+
7+
One apply a transformation on a matrix or a vector via `transform(t::Transform,v::AbstractVecOrMat)`
8+
9+
## Transforms :
10+
11+
```@docs
12+
IdentityTransform
13+
ScaleTransform
14+
LowRankTransform
15+
FunctionTransform
16+
ChainTransform
17+
```

docs/src/userguide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Building kernel and matrices easily!

src/KernelFunctions.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,40 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4-
export Kernel, SqExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
3+
export kernel, kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4+
export Kernel
5+
export ConstantKernel, WhiteKernel, ZeroKernel
6+
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
7+
export ExponentiatedKernel
8+
export MaternKernel, Matern32Kernel, Matern52Kernel
9+
export LinearKernel, PolynomialKernel
10+
export RationalQuadraticKernel, GammaRationalQuadraticKernel
11+
export KernelSum, KernelProduct
12+
513

6-
export Transform, ScaleTransform
714

815
using Distances, LinearAlgebra
916
using Zygote: @adjoint
1017
using SpecialFunctions: lgamma, besselk
1118
using StatsFuns: logtwo
1219

1320
const defaultobs = 2
14-
abstract type Kernel{T,Tr} end
1521

16-
include("zygote_rules.jl")
22+
# include("zygote_rules.jl")
1723
include("utils.jl")
24+
include("distances/dotproduct.jl")
25+
include("distances/delta.jl")
1826
include("transform/transform.jl")
19-
include("kernelmatrix.jl")
2027

21-
kernels = ["sqexponential","matern"]
28+
29+
abstract type Kernel{T,Tr<:Transform} end
30+
31+
kernels = ["exponential","matern","polynomial","constant","rationalquad","exponentiated"]
2232
for k in kernels
2333
include(joinpath("kernels",k*".jl"))
2434
end
35+
include("kernelmatrix.jl")
36+
include("kernels/kernelsum.jl")
37+
include("kernels/kernelproduct.jl")
2538

2639
include("generic.jl")
2740

src/distances/delta.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
struct Delta <: Distances.PreMetric
2+
end
3+
4+
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
5+
@boundscheck if length(a) != length(b)
6+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
7+
end
8+
return a==b
9+
end
10+
11+
@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
12+
@inline (dist::Delta)(a::Number,b::Number) = a==b

src/distances/dotproduct.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
struct DotProduct <: Distances.PreMetric
2+
end
3+
4+
@inline function Distances._evaluate(::DotProduct,a::AbstractVector{T},b::AbstractVector{T}) where {T}
5+
@boundscheck if length(a) != length(b)
6+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
7+
end
8+
return dot(a,b)
9+
end
10+
11+
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
12+
@inline (dist::DotProduct)(a::Number,b::Number) = a*b

src/generic.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
1-
21
@inline metric::Kernel) = κ.metric
3-
kernels =
4-
for k in [:SqExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
5-
eval(quote
2+
3+
## Allows to iterate over kernels
4+
Base.length(::Kernel) = 1
5+
6+
Base.iterate(k::Kernel) = (k,nothing)
7+
Base.iterate(k::Kernel, ::Any) = nothing
8+
9+
### Syntactic sugar for creating matrices and using kernel functions
10+
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
11+
@eval begin
612
@inline::$k)(d::Real) = kappa(κ,d)
7-
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
8-
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9-
@inline::$k)(x::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,obsdim=obsdim)
10-
end)
13+
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate.metric,transform(κ,x),transform(κ,y)))
14+
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
15+
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)
16+
end
1117
end
12-
### Transform generics
1318

19+
### Transform generics
1420
@inline transform::Kernel) = κ.transform
1521
@inline transform::Kernel,x::AbstractVecOrMat) = transform.transform,x)
1622
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform.transform,x,obsdim)
23+
24+
## Constructors for kernels without parameters
25+
for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52Kernel,:ExponentiatedKernel]
26+
@eval begin
27+
$kernel::T=1.0) where {T<:Real} = $kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
28+
$kernel::A) where {A<:AbstractVector{<:Real}} = $kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
29+
$kernel(t::Tr) where {Tr<:Transform} = $kernel{eltype(t),Tr}(t)
30+
end
31+
end

0 commit comments

Comments
 (0)