Skip to content

Created concrete types to call syntactic sugar on all kernels #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 30, 2020
9 changes: 6 additions & 3 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ printshifted(io::IO,κ::Kernel,shift::Int) = print(io,"$κ")
Base.show(io::IO,κ::Kernel) = print(io,nameof(typeof(κ)))

### Syntactic sugar for creating matrices and using kernel functions
for k in subtypes(BaseKernel)
if k ∈ [FBMKernel] continue end #for kernels without `metric` or `kappa`
function concretetypes(k, ktypes::Vector)
isempty(subtypes(k)) ? push!(ktypes, k) : concretetypes.(subtypes(k), Ref(ktypes))
return ktypes
end

for k in concretetypes(Kernel, [])
@eval begin
@inline (κ::$k)(d::Real) = kappa(κ,d) #TODO Add test
@inline (κ::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
@inline (κ::$k)(X::AbstractMatrix{T}, Y::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, Y, obsdim=obsdim)
@inline (κ::$k)(X::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, obsdim=obsdim)
Expand Down
50 changes: 22 additions & 28 deletions src/kernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,30 @@ For `h=1/2`, this is the Wiener Kernel, for `h>1/2`, the increments are
positively correlated and for `h<1/2` the increments are negatively correlated.
"""
struct FBMKernel{T<:Real} <: BaseKernel
h::T
h::Vector{T}
function FBMKernel(; h::T=0.5) where {T<:Real}
@assert h<=1.0 && h>=0.0 "FBMKernel: Given Hurst index h is invalid."
return new{T}(h)
@assert 0.0 <= h <= 1.0 "FBMKernel: Given Hurst index h is invalid."
return new{T}([h])
end
end

Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = $(k.h))")
Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = $(first(k.h)))")

const sqroundoff = 1e-15

_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h)/2

function kernelmatrix(κ::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
modX = sum(abs2, X; dims = 3 - obsdim)
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
modX = sum(abs2, X; dims = feature_dim(obsdim))
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
return _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
end

function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
modX = sum(abs2, X; dims = 3 - obsdim)
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
modX = sum(abs2, X; dims = feature_dim(obsdim))
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
K .= _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
return K
end
Expand All @@ -43,9 +45,9 @@ function kernelmatrix(
obsdim::Int = defaultobs,
)
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
modX = sum(abs2, X, dims=3-obsdim)
modY = sum(abs2, Y, dims=3-obsdim)
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
modX = sum(abs2, X, dims = feature_dim(obsdim))
modY = sum(abs2, Y, dims = feature_dim(obsdim))
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
return _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
end

Expand All @@ -57,9 +59,9 @@ function kernelmatrix!(
obsdim::Int = defaultobs,
)
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
modX = sum(abs2, X, dims=3-obsdim)
modY = sum(abs2, Y, dims=3-obsdim)
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
modX = sum(abs2, X, dims = feature_dim(obsdim))
modY = sum(abs2, Y, dims = feature_dim(obsdim))
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
K .= _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
return K
end
Expand All @@ -72,23 +74,15 @@ function _kernel(
obsdim::Int = defaultobs
)
@assert length(x) == length(y) "x and y don't have the same dimension!"
return κ(x,y)
return kappa(κ, x, y)
end

#Syntactic Sugar
function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
function kappa(κ::FBMKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
modX = sum(abs2, x)
modY = sum(abs2, y)
modXY = sqeuclidean(x, y)
return (modX^κ.h + modY^κ.h - modXY^κ.h)/2
modXY = evaluate(SqEuclidean(sqroundoff), x, y)
h = first(κ.h)
return (modX^h + modY^h - modXY^h)/2
end

(κ::FBMKernel)(x::Real, y::Real) = (abs2(x)^κ.h + abs2(y)^κ.h - abs2(x-y)^κ.h)/2

function (κ::FBMKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
return kernelmatrix(κ, X, Y, obsdim=obsdim)
end

function (κ::FBMKernel)(X::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
return kernelmatrix(κ, X, obsdim=obsdim)
end
(κ::FBMKernel)(x::Real, y::Real) = (abs2(x)^first(κ.h) + abs2(y)^first(κ.h) - abs2(x-y)^first(κ.h))/2
38 changes: 17 additions & 21 deletions src/matrix/kernelmatrix.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
```
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2)
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
```
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer = 2)
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer = 2)

In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will be overwritten with the kernel matrix.
"""
kernelmatrix!

Expand All @@ -21,7 +20,7 @@ function kernelmatrix!(
map!(x->kappa(κ,x),K,pairwise(metric(κ),X,dims=obsdim))
end

kernelmatrix!(K::Matrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
kernelmatrix!(K, kernel(κ), apply(κ.transform, X, obsdim = obsdim), obsdim = obsdim)

function kernelmatrix!(
Expand Down Expand Up @@ -61,13 +60,12 @@ _kernel(κ::TransformedKernel, x::AbstractVector, y::AbstractVector; obsdim::Int
_kernel(kernel(κ), apply(κ.transform, x), apply(κ.transform, y), obsdim = obsdim)

"""
```
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2)
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
```
kernelmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int = 2)

Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
`obsdim=1` means the matrix `X` (and `Y`) has size #samples x #dimension
`obsdim=2` means the matrix `X` (and `Y`) has size #dimension x #samples
`obsdim = 1` means the matrix `X` (and `Y`) has size #samples x #dimension
`obsdim = 2` means the matrix `X` (and `Y`) has size #dimension x #samples
"""
kernelmatrix

Expand Down Expand Up @@ -109,12 +107,11 @@ kernelmatrix(κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim
kernelmatrix(kernel(κ), apply(κ.transform, X, obsdim = obsdim), apply(κ.transform, Y, obsdim = obsdim), obsdim = obsdim)

"""
```
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
```
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)

Calculate the diagonal matrix of `X` with respect to kernel `κ`
`obsdim=1` means the matrix `X` has size #samples x #dimension
`obsdim=2` means the matrix `X` has size #dimension x #samples
`obsdim = 1` means the matrix `X` has size #samples x #dimension
`obsdim = 2` means the matrix `X` has size #dimension x #samples
"""
function kerneldiagmatrix(
κ::Kernel,
Expand All @@ -130,10 +127,9 @@ function kerneldiagmatrix(
end

"""
```
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int=2)
```
In place version of `kerneldiagmatrix`
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int = 2)

In place version of [`kerneldiagmatrix`](@ref)
"""
function kerneldiagmatrix!(
K::AbstractVector,
Expand Down
2 changes: 2 additions & 0 deletions src/trainable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import .Flux.trainable

trainable(k::ConstantKernel) = (k.c,)

trainable(k::FBMKernel) = (k.h,)

trainable(k::GammaExponentialKernel) = (k.γ,)

trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)
Expand Down
2 changes: 0 additions & 2 deletions test/kernels/custom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
KernelFunctions.metric(::MyKernel) = SqEuclidean()

# some syntactic sugar
(κ::MyKernel)(d::Real) = kappa(κ, d)
(κ::MyKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
(κ::MyKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X, Y; obsdim = obsdim)
(κ::MyKernel)(X::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X; obsdim = obsdim)
Expand All @@ -17,7 +16,6 @@ KernelFunctions.metric(::MyKernel) = SqEuclidean()
@test kernelmatrix(MyKernel(), [1 2; 3 4], [5 6; 7 8]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4], [5 6; 7 8])
@test kernelmatrix(MyKernel(), [1 2; 3 4]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4])

@test MyKernel()(3) == SqExponentialKernel()(3)
@test MyKernel()([1, 2], [3, 4]) == SqExponentialKernel()([1, 2], [3, 4])
@test MyKernel()([1 2; 3 4], [5 6; 7 8]) == SqExponentialKernel()([1 2; 3 4], [5 6; 7 8])
@test MyKernel()([1 2; 3 4]) == SqExponentialKernel()([1 2; 3 4])
Expand Down
17 changes: 17 additions & 0 deletions test/kernels/fbm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@testset "FBM" begin
h = 0.3
k = FBMKernel(h = h)
v1 = rand(3); v2 = rand(3)
@test k(v1,v2) ≈ (sqeuclidean(v1, zero(v1))^h + sqeuclidean(v2, zero(v2))^h - sqeuclidean(v1-v2, zero(v1-v2))^h)/2 atol=1e-5

# kernelmatrix tests
m1 = rand(3,3)
m2 = rand(3,3)
@test kernelmatrix(k, m1, m1) ≈ kernelmatrix(k, m1) atol=1e-5
@test kernelmatrix(k, m1, m2) ≈ k(m1, m2) atol=1e-5


x1 = rand()
x2 = rand()
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5
end
13 changes: 7 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,18 @@ using KernelFunctions: metric
end

@testset "kernels" begin
include(joinpath("kernels", "constant.jl"))
include(joinpath("kernels", "cosine.jl"))
include(joinpath("kernels", "exponential.jl"))
include(joinpath("kernels", "exponentiated.jl"))
include(joinpath("kernels", "fbm.jl"))
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "matern.jl"))
include(joinpath("kernels", "polynomial.jl"))
include(joinpath("kernels", "constant.jl"))
include(joinpath("kernels", "rationalquad.jl"))
include(joinpath("kernels", "exponentiated.jl"))
include(joinpath("kernels", "cosine.jl"))
include(joinpath("kernels", "transformedkernel.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "transformedkernel.jl"))

# Legacy tests that don't correspond to anything meaningful in src. Unclear how
# helpful these are.
Expand Down
20 changes: 15 additions & 5 deletions test/trainable.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
@testset "trainable" begin
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5

kc = ConstantKernel(c=c)
@test all(params(kc) .== params([c]))
km = MaternKernel(ν=ν)
@test all(params(km) .== params([ν]))
kl = LinearKernel(c=c)
@test all(params(kl) .== params([c]))

kfbm = FBMKernel(h = h)
@test all(params(kfbm) .== params([h]))

kge = GammaExponentialKernel(γ=γ)
@test all(params(kge) .== params([γ]))

kgr = GammaRationalQuadraticKernel(γ=γ, α=α)
@test all(params(kgr) .== params([α], [γ]))

kl = LinearKernel(c=c)
@test all(params(kl) .== params([c]))

km = MaternKernel(ν=ν)
@test all(params(km) .== params([ν]))

kp = PolynomialKernel(c=c, d=d)
@test all(params(kp) .== params([d], [c]))

kr = RationalQuadraticKernel(α=α)
@test all(params(kr) .== params([α]))

Expand Down