Skip to content

Commit 252f595

Browse files
authored
Merge pull request #50 from JuliaGaussianProcesses/periodickernel
Add Periodic Kernel
2 parents b267b54 + ff08795 commit 252f595

File tree

14 files changed

+102
-13
lines changed

14 files changed

+102
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515

1616
[compat]
1717
Compat = "2.2, 3"
18-
Distances = "0.8"
18+
Distances = "0.8.2"
1919
Requires = "1.0.1"
2020
SpecialFunctions = "0.8, 0.9, 0.10"
2121
StatsBase = "0.32, 0.33"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ LinearKernel
2828
PolynomialKernel
2929
RationalQuadraticKernel
3030
GammaRationalQuadraticKernel
31+
PeriodicKernel
3132
ZeroKernel
3233
ConstantKernel
3334
WhiteKernel

docs/src/kernels.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ The [Polynomial Kernel](@ref KernelFunctions.PolynomialKernel) is defined as
9191
k(x,x';c,d) = \left(\langle x,x'\rangle + c\right)^d
9292
```
9393

94+
## Periodic Kernels
95+
96+
### PeriodicKernel
97+
98+
```math
99+
k(x,x';r) = \exp\left(-0.5 \sum_i (sin (π(x_i - x'_i))/r_i)^2\right)
100+
```
101+
94102
## Constant Kernels
95103

96104
### ConstantKernel

src/KernelFunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export MaternKernel, Matern32Kernel, Matern52Kernel
1616
export LinearKernel, PolynomialKernel
1717
export RationalQuadraticKernel, GammaRationalQuadraticKernel
1818
export MahalanobisKernel, GaborKernel, PiecewisePolynomialKernel
19+
export PeriodicKernel
1920
export KernelSum, KernelProduct
2021
export TransformedKernel, ScaledKernel
2122

@@ -44,9 +45,10 @@ abstract type BaseKernel <: Kernel end
4445
include("utils.jl")
4546
include("distances/dotproduct.jl")
4647
include("distances/delta.jl")
48+
include("distances/sinus.jl")
4749
include("transform/transform.jl")
4850

49-
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha","fbm","gabor","piecewisepolynomial"]
51+
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha","fbm","gabor","periodic","piecewisepolynomial"]
5052
include(joinpath("kernels",k*".jl"))
5153
end
5254
include("kernels/transformedkernel.jl")

src/distances/delta.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ end
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
8-
return a==b
8+
return a == b
99
end
1010

11-
@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
12-
@inline (dist::Delta)(a::Number,b::Number) = a==b
11+
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
12+
@inline (dist::Delta)(a::Number,b::Number) = a == b

src/distances/dotproduct.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
struct DotProduct <: Distances.PreMetric
2-
end
1+
struct DotProduct <: Distances.PreMetric end
2+
# struct DotProduct <: Distances.UnionSemiMetric end
33

4-
@inline function Distances._evaluate(::DotProduct,a::AbstractVector{T},b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return dot(a,b)
99
end
1010

11-
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
12-
@inline (dist::DotProduct)(a::Number,b::Number) = a*b
11+
@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
12+
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
13+
@inline (dist::DotProduct)(a::Number,b::Number) = a * b

src/distances/sinus.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
struct Sinus{T} <: Distances.SemiMetric
2+
# struct Sinus{T} <: Distances.UnionSemiMetric
3+
r::Vector{T}
4+
end
5+
6+
Distances.parameters(d::Sinus) = d.r
7+
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
8+
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
9+
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
10+
11+
@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
12+
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
13+
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
14+
end
15+
return sum(abs2, sinpi.(a - b) ./ d.r)
16+
end

src/kernels/periodic.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
PeriodicKernel(r::AbstractVector)
3+
PeriodicKernel(dims::Int)
4+
PeriodicKernel(T::DataType, dims::Int)
5+
6+
Periodic Kernel as described in http://www.inference.org.uk/mackay/gpB.pdf eq. 47.
7+
```
8+
κ(x,y) = exp( - 0.5 sum_i(sin (π(x_i - y_i))/r_i))
9+
```
10+
"""
11+
struct PeriodicKernel{T} <: BaseKernel
12+
r::Vector{T}
13+
function PeriodicKernel(; r::AbstractVector{T} = ones(Float64, 1)) where {T<:Real}
14+
@assert all(r .> 0)
15+
new{T}(r)
16+
end
17+
end
18+
19+
PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)
20+
21+
PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
22+
23+
metric::PeriodicKernel) = Sinus.r)
24+
25+
kappa::PeriodicKernel, d::Real) = exp(- 0.5d)
26+
27+
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = $(length.r))")

src/trainable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ trainable(k::MaternKernel) = (k.ν,)
1414

1515
trainable(k::LinearKernel) = (k.c,)
1616

17+
trainable(k::PeriodicKernel) = (k.r,)
18+
1719
trainable(k::PolynomialKernel) = (k.d, k.c)
1820

1921
trainable(k::RationalQuadraticKernel) = (k.α,)

src/zygote_adjoints.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
2-
dot(x,y), Δ -> begin
3-
(nothing, Δ.*y, Δ.*x)
2+
dot(x, y), Δ -> begin
3+
(nothing, Δ .* y, Δ .* x)
44
end
55
end
6+
7+
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
8+
# d = evaluate(s, x, y)
9+
# s = sum(sin.(π*(x-y)))
10+
# d, Δ -> begin
11+
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d)
12+
# end
13+
# end

test/distances/sinus.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testset "sinus" begin
2+
A = rand(10)
3+
B = rand(10)
4+
p = rand(10)
5+
d = KernelFunctions.Sinus(p)
6+
@test Distances.parameters(d) == p
7+
@test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p))
8+
@test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
9+
end

test/kernels/periodic.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@testset "Periodic Kernel" begin
2+
x = rand()*2; v1 = rand(3); v2 = rand(3);
3+
r = rand(3)
4+
k = PeriodicKernel(r = r)
5+
@test kappa(k, x) exp(-0.5x)
6+
@test k(v1, v2) exp(-0.5 * sum(abs2, sinpi.(v1 - v2) ./ r))
7+
@test k(v1, v2) == k(v2, v1)
8+
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
9+
@test_nowarn println(k)
10+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ using KernelFunctions: metric
4949
@testset "distances" begin
5050
include(joinpath("distances", "dotproduct.jl"))
5151
include(joinpath("distances", "delta.jl"))
52+
include(joinpath("distances", "sinus.jl"))
5253
end
5354

5455
@testset "transform" begin
@@ -71,6 +72,7 @@ using KernelFunctions: metric
7172
include(joinpath("kernels", "kernelproduct.jl"))
7273
include(joinpath("kernels", "kernelsum.jl"))
7374
include(joinpath("kernels", "matern.jl"))
75+
include(joinpath("kernels", "periodic.jl"))
7476
include(joinpath("kernels", "polynomial.jl"))
7577
include(joinpath("kernels", "piecewisepolynomial.jl"))
7678
include(joinpath("kernels", "rationalquad.jl"))

test/trainable.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "trainable" begin
2-
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5
2+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5; r = rand(3)
33

44
kc = ConstantKernel(c=c)
55
@test all(params(kc) .== params([c]))
@@ -22,6 +22,9 @@
2222
kp = PolynomialKernel(c=c, d=d)
2323
@test all(params(kp) .== params([d], [c]))
2424

25+
kpe = PeriodicKernel(r = r)
26+
@test all(params(kpe) .== params(r))
27+
2528
kr = RationalQuadraticKernel=α)
2629
@test all(params(kr) .== params([α]))
2730

0 commit comments

Comments
 (0)