Skip to content

Add Piecewise Polynomial Kernel #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export ExponentiatedKernel
export MaternKernel, Matern32Kernel, Matern52Kernel
export LinearKernel, PolynomialKernel
export RationalQuadraticKernel, GammaRationalQuadraticKernel
export MahalanobisKernel
export MahalanobisKernel, PiecewisePolynomialKernel
export KernelSum, KernelProduct
export TransformedKernel, ScaledKernel

Expand Down Expand Up @@ -46,7 +46,7 @@ include("distances/dotproduct.jl")
include("distances/delta.jl")
include("transform/transform.jl")

for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha","fbm"]
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha","fbm","piecewisepolynomial"]
include(joinpath("kernels",k*".jl"))
end
include("kernels/transformedkernel.jl")
Expand Down
118 changes: 118 additions & 0 deletions src/kernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
PiecewisePolynomialKernel{V}(maha::AbstractMatrix)

Piecewise Polynomial covariance function with compact support, V = 0,1,2,3.
The kernel functions are 2v times continuously differentiable and the corresponding
processes are hence v times mean-square differentiable. The kernel function is:
```math
κ(x,y) = max(1-r,0)^(j+V) * f(r,j) with j = floor(D/2)+V+1
```
where `r` is the Mahalanobis distance mahalanobis(x,y) with `maha` as the metric.

"""
struct PiecewisePolynomialKernel{V, A<:AbstractMatrix{<:Real}} <: BaseKernel
maha::A
function PiecewisePolynomialKernel{V}(maha::AbstractMatrix{<:Real}) where V
V in (0, 1, 2, 3) || error("Invalid paramter v=$(V). Should be 0, 1, 2 or 3.")
LinearAlgebra.checksquare(maha)
return new{V,typeof(maha)}(maha)
end
end

function PiecewisePolynomialKernel(;v::Integer=0, maha::AbstractMatrix{<:Real})
return PiecewisePolynomialKernel{v}(maha)
end

function _f(κ::PiecewisePolynomialKernel{V}, r, j) where V
if V==0
return 1
elseif V==1
return 1 + (j + 1) * r
elseif V==2
return 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r.^2
elseif V==3
return 1 + (j + 3) * r + (6 * j^2 + 36j + 45) / 15 * r.^2 +
(j^3 + 9 * j^2 + 23j + 15) / 15 * r.^3
else
error("Invalid paramter v=$(V). Should be 0,1,2 or 3.")
end
end

function _piecewisepolynomial(κ::PiecewisePolynomialKernel{V}, r, j) where V
return max(1 - r, 0)^(j + V) * _f(κ, r, j)
end

function kappa(
κ::PiecewisePolynomialKernel{V},
x::AbstractVector{<:Real},
y::AbstractVector{<:Real},
) where {V}
r = evaluate(metric(κ), x, y)
j = div(size(x, 2), 1) + V + 1
return _piecewisepolynomial(κ, r, j)
end

function _kernel(
κ::PiecewisePolynomialKernel,
x::AbstractVector,
y::AbstractVector;
obsdim::Int = defaultobs,
)
@assert length(x) == length(y) "x and y don't have the same dimension!"
return kappa(κ,x,y)
end

function kernelmatrix(
κ::PiecewisePolynomialKernel{V},
X::AbstractMatrix;
obsdim::Int = defaultobs
) where {V}
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X; dims=obsdim))
end

function _kernelmatrix(κ::PiecewisePolynomialKernel{V}, X, Y; obsdim) where {V}
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X, Y; dims=obsdim))
end

function kernelmatrix!(
K::AbstractMatrix,
κ::PiecewisePolynomialKernel{V},
X::AbstractMatrix;
obsdim::Int = defaultobs
) where {V}
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
throw(DimensionMismatch(
"Dimensions of the target array K $(size(K)) are not consistent with X " *
"$(size(X))",
))
end
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X; dims=obsdim))
end

function kernelmatrix!(
K::AbstractMatrix,
κ::PiecewisePolynomialKernel{V},
X::AbstractMatrix,
Y::AbstractMatrix;
obsdim::Int = defaultobs,
) where {V}
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
throw(DimensionMismatch(
"Dimensions $(size(K)) of the target array K are not consistent with X " *
"($(size(X))) and Y ($(size(Y)))",
))
end
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X, Y; dims=obsdim))
end

metric(κ::PiecewisePolynomialKernel) = Mahalanobis(κ.maha)

function Base.show(io::IO, κ::PiecewisePolynomialKernel{V}) where {V}
print(io, "Piecewise Polynomial Kernel (v = $(V), size(maha) = $(size(κ.maha))")
end
33 changes: 33 additions & 0 deletions test/kernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
@testset "piecewisepolynomial" begin
v1 = rand(3)
v2 = rand(3)
m1 = rand(3, 4)
m2 = rand(3, 4)
maha = ones(3, 3)
k = PiecewisePolynomialKernel{3}(maha)

k2 = PiecewisePolynomialKernel(v=3, maha=maha)

@test k2(v1, v2) ≈ k(v1, v2) atol=1e-5

@test k(v1, v2) ≈ kappa(k, v1, v2) atol=1e-5
@test typeof(k(v1, v2)) <: Real
@test size(k(m1, m2)) == (4, 4)
@test size(k(m1)) == (4, 4)

A1 = ones(4, 4)
kernelmatrix!(A1, k, m1, m2)
@test A1 ≈ kernelmatrix(k, m1, m2) atol=1e-5

A2 = ones(4, 4)
kernelmatrix!(A2, k, m1)
@test A2 ≈ kernelmatrix(k, m1) atol=1e-5

@test size(kerneldiagmatrix(k, m1)) == (4,)
@test kerneldiagmatrix(k, m1) == ones(4)
A3 = ones(4)
kerneldiagmatrix!(A3, k, m1)
@test A3 == kerneldiagmatrix(k, m1)

@test_throws ErrorException PiecewisePolynomialKernel{4}(maha)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using KernelFunctions: metric
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "matern.jl"))
include(joinpath("kernels", "polynomial.jl"))
include(joinpath("kernels", "piecewisepolynomial.jl"))
include(joinpath("kernels", "rationalquad.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "transformedkernel.jl"))
Expand Down