Skip to content

Commit 32888d5

Browse files
sharanrydevmotionwilltebbutt
authored
Add Fractional Brownian motion kernel (#48)
* Add Fractional Brownian motion kernel * Add new line at the end * Add kernalmatrix function(s) * Cover more cases * Fix bug * Update src/kernels/fbm.jl Co-Authored-By: David Widmann <[email protected]> * Add tests * Add scalar case * Update src/kernels/fbm.jl Co-Authored-By: David Widmann <[email protected]> * Update src/kernels/fbm.jl Co-Authored-By: David Widmann <[email protected]> * Add more tests * Update src/kernels/fbm.jl Co-Authored-By: willtebbutt <[email protected]> * Update src/kernels/fbm.jl Co-Authored-By: willtebbutt <[email protected]> * Update src/kernels/fbm.jl Co-Authored-By: willtebbutt <[email protected]> * Apply suggestions from code review Co-Authored-By: willtebbutt <[email protected]> * Add kernelmatrix! for Fractional Brownian motion kernel * Make FBM functions actually mutable * Fix kernelmatrix! for FBMKernel * Add _kernel for FBMKernel to make kerneldiagmatrix work Co-authored-by: David Widmann <[email protected]> Co-authored-by: willtebbutt <[email protected]>
1 parent 1081e46 commit 32888d5

File tree

5 files changed

+114
-2
lines changed

5 files changed

+114
-2
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ include("distances/dotproduct.jl")
4545
include("distances/delta.jl")
4646
include("transform/transform.jl")
4747

48-
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha"]
48+
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha","fbm"]
4949
include(joinpath("kernels",k*".jl"))
5050
end
5151
include("kernels/transformedkernel.jl")

src/generic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Base.show(io::IO,κ::Kernel) = print(io,nameof(typeof(κ)))
1616

1717
### Syntactic sugar for creating matrices and using kernel functions
1818
for k in subtypes(BaseKernel)
19+
if k [FBMKernel] continue end #for kernels without `metric` or `kappa`
1920
@eval begin
2021
@inline::$k)(d::Real) = kappa(κ,d) #TODO Add test
2122
@inline::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)

src/kernels/fbm.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
FBMKernel(; h::Real=0.5)
3+
4+
Fractional Brownian motion kernel with Hurst index h from (0,1) given by
5+
```
6+
κ(x,y) = ( |x|²ʰ + |y|²ʰ - |x-y|²ʰ ) / 2
7+
```
8+
9+
For h=1/2, this is the Wiener Kernel, for h>1/2, the increments are
10+
positively correlated and for h<1/2 the increments are negatively correlated.
11+
"""
12+
struct FBMKernel{T<:Real} <: BaseKernel
13+
h::T
14+
function FBMKernel(;h::T=0.5) where {T<:Real}
15+
@assert h<=1.0 && h>=0.0 "FBMKernel: Given Hurst index h is invalid."
16+
return new{T}(h)
17+
end
18+
end
19+
20+
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h)/2
21+
22+
function kernelmatrix::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
23+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
24+
modX = sum(abs2, X; dims = 3 - obsdim)
25+
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
26+
return _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
27+
end
28+
29+
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
30+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
31+
modX = sum(abs2, X; dims = 3 - obsdim)
32+
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
33+
K .= _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
34+
return K
35+
end
36+
37+
function kernelmatrix(
38+
κ::FBMKernel,
39+
X::AbstractMatrix,
40+
Y::AbstractMatrix;
41+
obsdim::Int = defaultobs,
42+
)
43+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
44+
modX = sum(abs2, X, dims=3-obsdim)
45+
modY = sum(abs2, Y, dims=3-obsdim)
46+
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
47+
return _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
48+
end
49+
50+
function kernelmatrix!(
51+
K::AbstractMatrix,
52+
κ::FBMKernel,
53+
X::AbstractMatrix,
54+
Y::AbstractMatrix;
55+
obsdim::Int = defaultobs,
56+
)
57+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
58+
modX = sum(abs2, X, dims=3-obsdim)
59+
modY = sum(abs2, Y, dims=3-obsdim)
60+
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
61+
K .= _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
62+
return K
63+
end
64+
65+
function _kernel::FBMKernel, x::Real, y::Real)
66+
_kernel(κ, [x], [y])
67+
end
68+
69+
## Apply kernel on two vectors ##
70+
function _kernel(
71+
κ::FBMKernel,
72+
x::AbstractVector,
73+
y::AbstractVector;
74+
obsdim::Int = defaultobs
75+
)
76+
@assert length(x) == length(y) "x and y don't have the same dimension!"
77+
return κ(x,y)
78+
end
79+
80+
#Syntactic Sugar
81+
function::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
82+
modX = sum(abs2, x)
83+
modY = sum(abs2, y)
84+
modXY = sqeuclidean(x, y)
85+
return (modX^κ.h + modY^κ.h - modXY^κ.h)/2
86+
end
87+
88+
::FBMKernel)(x::Real, y::Real) = (abs2(x)^κ.h + abs2(y)^κ.h - abs2(x-y)^κ.h)/2
89+
90+
function::FBMKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
91+
return kernelmatrix(κ, X, Y, obsdim=obsdim)
92+
end
93+
94+
function::FBMKernel)(X::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
95+
return kernelmatrix(κ, X, obsdim=obsdim)
96+
end

src/matrix/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ kernelmatrix!
99

1010

1111
function kernelmatrix!(
12-
K::Matrix,
12+
K::AbstractMatrix,
1313
κ::Kernel,
1414
X::AbstractMatrix;
1515
obsdim::Int = defaultobs

test/test_kernels.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
2626
@test kappa(k,0.5) == c
2727
end
2828
end
29+
@testset "FBM" begin
30+
k = FBMKernel(h=0.3)
31+
@test k(v1,v2) (sqeuclidean(v1, zero(v1))^0.3 + sqeuclidean(v2, zero(v2))^0.3 - sqeuclidean(v1-v2, zero(v1-v2))^0.3)/2 atol=1e-5
32+
33+
# kernelmatrix tests
34+
m1 = rand(3,3)
35+
m2 = rand(3,3)
36+
@test kernelmatrix(k, m1, m1) kernelmatrix(k, m1) atol=1e-5
37+
@test kernelmatrix(k, m1, m2) k(m1, m2) atol=1e-5
38+
39+
40+
x1 = rand()
41+
x2 = rand()
42+
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] k(x1, x2) atol=1e-5
43+
end
2944
@testset "Cosine" begin
3045
k = CosineKernel()
3146
@test eltype(k) == Any

0 commit comments

Comments
 (0)