Skip to content

Commit a4572ad

Browse files
sharanrydevmotion
andauthored
Add i-times integrated Wiener Kernel (JuliaGaussianProcesses#77)
* Add WienerKernel * Fix bugs * Modify _wiener * Add kernel matrix functions * Add tests * Modify _wiener to be separate for each i * Style fix * Move wiener kernel to basekernels * Use @Assert * Fix docstring * Remove kappa and _kernel * Apply suggestions from code review Fix docstring and Base.show Co-Authored-By: David Widmann <[email protected]> * Remove _wiener Co-Authored-By: David Widmann <[email protected]> * Apply suggestions from code review * Make docstring more readable * Make docstring more readable * Update Base.show Co-Authored-By: David Widmann <[email protected]> * Update docstring * Update docstring * Remove test for k0(m1, m2) Co-Authored-By: David Widmann <[email protected]> * Add more tests * Fix tests Co-Authored-By: David Widmann <[email protected]> * Restrict line length Co-authored-by: David Widmann <[email protected]>
1 parent c006d47 commit a4572ad

File tree

4 files changed

+133
-1
lines changed

4 files changed

+133
-1
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export transform
99
export duplicate, set! # Helpers
1010

1111
export Kernel
12-
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel
12+
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel
1313
export CosineKernel
1414
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
1515
export LaplacianKernel, ExponentialKernel, GammaExponentialKernel

src/basekernels/wiener.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
WienerKernel{i}()
3+
4+
i-times integrated Wiener process kernel function.
5+
6+
- For i=-1, this is just the white noise covariance, see [`WhiteKernel`](@ref).
7+
- For i= 0, this is the Wiener process covariance,
8+
- For i= 1, this is the integrated Wiener process covariance (velocity),
9+
- For i= 2, this is the twice-integrated Wiener process covariance (accel.),
10+
- For i= 3, this is the thrice-integrated Wiener process covariance,
11+
12+
where ``κᵢ`` is given by
13+
14+
```math
15+
κ₋₁(x, y) = δ(x, y)
16+
κ₀(x, y) = min(x, y)
17+
```
18+
and for ``i >= 1``,
19+
```math
20+
κᵢ(x, y) = 1 / aᵢ * min(x, y)^(2i + 1) + bᵢ * min(x, y)^(i + 1) * |x - y| * rᵢ(x, y),
21+
```
22+
with the coefficients ``aᵢ``, ``bᵢ`` and the residual ``rᵢ(x, y)`` defined as follows:
23+
```math
24+
a₁ = 3, b₁ = 1/2, r₁(x, y) = 1,
25+
a₂ = 20, b₂ = 1/12, r₂(x, y) = x + y - min(x, y) / 2,
26+
a₃ = 252, b₃ = 1/720, r₃(x, y) = 5 * max(x, y)² + 2 * x * z + 3 * min(x, y)²
27+
28+
```
29+
30+
# References:
31+
See the paper *Probabilistic ODE Solvers with Runge-Kutta Means* by Schober, Duvenaud and
32+
Hennig, NIPS, 2014, for more details.
33+
34+
"""
35+
struct WienerKernel{I} <: BaseKernel
36+
function WienerKernel{I}() where I
37+
@assert I (-1, 0, 1, 2, 3) "Invalid parameter i=$(I). Should be -1, 0, 1, 2 or 3."
38+
if I == -1
39+
return WhiteKernel()
40+
end
41+
return new{I}()
42+
end
43+
end
44+
45+
function WienerKernel(;i::Integer=0)
46+
return WienerKernel{i}()
47+
end
48+
49+
function (::WienerKernel{0})(x, y)
50+
X = sqrt(sum(abs2, x))
51+
Y = sqrt(sum(abs2, y))
52+
return min(X, Y)
53+
end
54+
55+
function (::WienerKernel{1})(x, y)
56+
X = sqrt(sum(abs2, x))
57+
Y = sqrt(sum(abs2, y))
58+
minXY = min(X, Y)
59+
return 1 / 3 * minXY^3 + 1 / 2 * minXY^2 * euclidean(x, y)
60+
end
61+
62+
function (::WienerKernel{2})(x, y)
63+
X = sqrt(sum(abs2, x))
64+
Y = sqrt(sum(abs2, y))
65+
minXY = min(X, Y)
66+
return 1 / 20 * minXY^5 + 1 / 12 * minXY^3 * euclidean(x, y) *
67+
( X + Y - 1 / 2 * minXY )
68+
end
69+
70+
function (::WienerKernel{3})(x, y)
71+
X = sqrt(sum(abs2, x))
72+
Y = sqrt(sum(abs2, y))
73+
minXY = min(X, Y)
74+
return 1 / 252 * minXY^7 + 1 / 720 * minXY^4 * euclidean(x, y) *
75+
( 5 * max(X, Y)^2 + 2 * X * Y + 3 * minXY^2 )
76+
end
77+
78+
Base.show(io::IO, ::WienerKernel{I}) where I = print(io, I, "-times integrated Wiener kernel")

test/basekernels/wiener.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
@testset "wiener" begin
2+
k_1 = WienerKernel(i=-1)
3+
@test typeof(k_1) <: WhiteKernel
4+
5+
k0 = WienerKernel()
6+
@test typeof(k0) <: WienerKernel{0}
7+
8+
k1 = WienerKernel(i=1)
9+
@test typeof(k1) <: WienerKernel{1}
10+
11+
k2 = WienerKernel(i=2)
12+
@test typeof(k2) <: WienerKernel{2}
13+
14+
k3 = WienerKernel(i=3)
15+
@test typeof(k3) <: WienerKernel{3}
16+
17+
@test_throws AssertionError WienerKernel(i=4)
18+
@test_throws AssertionError WienerKernel(i=-2)
19+
20+
v1 = rand(4)
21+
v2 = rand(4)
22+
23+
X = sqrt(sum(abs2, v1))
24+
Y = sqrt(sum(abs2, v2))
25+
minXY = min(X, Y)
26+
27+
@test k0(v1, v2) minXY
28+
@test k1(v1, v2) 1 / 3 * minXY^3 + 1 / 2 * minXY^2 * euclidean(v1, v2)
29+
@test k2(v1, v2) 1 / 20 * minXY^5 + 1 / 12 * minXY^3 * euclidean(v1, v2) *
30+
( X + Y - 1 / 2 * minXY )
31+
@test k3(v1, v2) 1 / 252 * minXY^7 + 1 / 720 * minXY^4 * euclidean(v1, v2) *
32+
( 5 * max(X, Y)^2 + 2 * X * Y + 3 * minXY^2 )
33+
34+
# kernelmatrix tests
35+
m1 = rand(3,4)
36+
m2 = rand(3,4)
37+
@test kernelmatrix(k0, m1, m1) kernelmatrix(k0, m1) atol=1e-5
38+
39+
K = zeros(4,4)
40+
kernelmatrix!(K,k0,m1,m2)
41+
@test K kernelmatrix(k0, m1, m2) atol=1e-5
42+
43+
V = zeros(4)
44+
kerneldiagmatrix!(V,k0,m1)
45+
@test V kerneldiagmatrix(k0,m1) atol=1e-5
46+
47+
x1 = rand()
48+
x2 = rand()
49+
@test kernelmatrix(k0, x1*ones(1,1), x2*ones(1,1))[1] k0(x1, x2) atol=1e-5
50+
@test kernelmatrix(k1, x1*ones(1,1), x2*ones(1,1))[1] k1(x1, x2) atol=1e-5
51+
@test kernelmatrix(k2, x1*ones(1,1), x2*ones(1,1))[1] k2(x1, x2) atol=1e-5
52+
@test kernelmatrix(k3, x1*ones(1,1), x2*ones(1,1))[1] k3(x1, x2) atol=1e-5
53+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ using KernelFunctions: metric, kappa
7676
include(joinpath("basekernels", "polynomial.jl"))
7777
include(joinpath("basekernels", "piecewisepolynomial.jl"))
7878
include(joinpath("basekernels", "rationalquad.jl"))
79+
include(joinpath("basekernels", "wiener.jl"))
7980
end
8081

8182
@testset "kernels" begin

0 commit comments

Comments
 (0)