Skip to content

Add gaborkernel and simplify kernels with IdentityTransforms #285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.9.6"
version = "0.9.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ FBMKernel
### Gabor Kernel

```@docs
gaborkernel
GaborKernel
```

Expand Down
3 changes: 2 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export Transform,

export NystromFact, nystrom

export gaborkernel
export spectral_mixture_kernel, spectral_mixture_product_kernel

export ColVecs, RowVecs
Expand Down Expand Up @@ -73,6 +74,7 @@ include(joinpath("transform", "functiontransform.jl"))
include(joinpath("transform", "selecttransform.jl"))
include(joinpath("transform", "chaintransform.jl"))
include(joinpath("transform", "periodic_transform.jl"))
include(joinpath("kernels", "transformedkernel.jl"))

include(joinpath("basekernels", "constant.jl"))
include(joinpath("basekernels", "cosine.jl"))
Expand All @@ -89,7 +91,6 @@ include(joinpath("basekernels", "rational.jl"))
include(joinpath("basekernels", "sm.jl"))
include(joinpath("basekernels", "wiener.jl"))

include(joinpath("kernels", "transformedkernel.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "normalizedkernel.jl"))
include(joinpath("matrix", "kernelmatrix.jl"))
Expand Down
45 changes: 40 additions & 5 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
"""
gaborkernel(;
sqexponential_transform=IdentityTransform(), cosine_tranform=IdentityTransform()
)

Construct a Gabor kernel with transformations `sqexponential_transform` and
`cosine_transform` of the inputs of the underlying squared exponential and cosine kernel,
respectively.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the Gabor kernel with transformations ``f``
and ``g`` of the inputs to the squared exponential and cosine kernel, respectively,
is defined as
```math
k(x, x'; f, g) = \\exp\\bigg(- \\frac{\\| f(x) - f(x')\\|_2^2}{2}\\bigg)
\\cos\\big(\\pi \\|g(x) - g(x')\\|_2 \\big).
```
"""
function gaborkernel(;
sqexponential_transform=IdentityTransform(), cosine_transform=IdentityTransform()
)
return (SqExponentialKernel() ∘ sqexponential_transform) *
(CosineKernel() ∘ cosine_transform)
end

# everything below will be removed
"""
GaborKernel(; ell::Real=1.0, p::Real=1.0)

Expand All @@ -11,11 +38,20 @@ and period ``p_i > 0`` is defined as
k(x, x'; l, p) = \\exp\\bigg(- \\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{2l_i^2}\\bigg)
\\cos\\bigg(\\pi \\bigg(\\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{p_i^2} \\bigg)^{1/2}\\bigg).
```

!!! note
`GaborKernel` is deprecated and will be removed. Gabor kernels should be
constructed with [`gaborkernel`](@ref) instead.
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
Base.depwarn(
"`GaborKernel` is deprecated and will be removed. Gabor kernels should be " *
"constructed with `gaborkernel` instead.",
:GaborKernel,
)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform)
Expand All @@ -31,19 +67,18 @@ _lengthscale_transform(::Nothing) = IdentityTransform()
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(::IdentityTransform) = 1
_lengthscale(x) = 1
_lengthscale(k::TransformedKernel) = _lengthscale(k.transform)
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
ell_transform = k.kernel.kernels[1].transform
return _lengthscale(ell_transform)
return _lengthscale(k.kernel.kernels[1])
elseif v == :p
p_transform = k.kernel.kernels[2].transform
return _lengthscale(p_transform)
return _lengthscale(k.kernel.kernels[2])
else
error("Invalid Property")
end
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ See also: [`TransformedKernel`](@ref)
Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t)
Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform ∘ t)

# Simplify kernels with identity transformation of the inputs
Base.:∘(k::Kernel, ::IdentityTransform) = k
Base.:∘(k::TransformedKernel, ::IdentityTransform) = k

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

function printshifted(io::IO, κ::TransformedKernel, shift::Int)
Expand Down
55 changes: 40 additions & 15 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,50 @@
@testset "Gabor" begin
v1 = rand(3)
v2 = rand(3)
ell = abs(rand())
p = abs(rand())
k = GaborKernel(; ell=ell, p=p)
@test k.ell ≈ ell atol = 1e-5
@test k.p ≈ p atol = 1e-5
ell = rand()
p = rand()
k = gaborkernel(;
sqexponential_transform=ScaleTransform(inv(ell)),
cosine_transform=ScaleTransform(inv(p)),
)
@test k isa KernelProduct{
<:Tuple{
TransformedKernel{SqExponentialKernel,<:ScaleTransform},
TransformedKernel{CosineKernel,<:ScaleTransform},
},
}
@test k.kernels[1].transform.s[1] == inv(ell)
@test k.kernels[2].transform.s[1] == inv(p)

k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) ≈ k_manual atol = 1e-5
k_manual = exp(-sqeuclidean(v1, v2) / (2 * ell^2)) * cospi(euclidean(v1, v2) / p)
@test k_manual ≈ k(v1, v2) atol = 1e-5

lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) ≈ lhs_manual * rhs_manual atol = 1e-5
lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / ell))(v1, v2)
rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / p))(v1, v2)
@test lhs_manual * rhs_manual ≈ k(v1, v2) atol = 1e-5

k = GaborKernel()
@test k.ell ≈ 1.0 atol = 1e-5
@test k.p ≈ 1.0 atol = 1e-5
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"
@test gaborkernel() isa KernelProduct{Tuple{SqExponentialKernel,CosineKernel}}

test_interface(k, Vector{Float64})
test_ADs(
x -> gaborkernel(;
sqexponential_transform=ScaleTransform(x[1]),
cosine_transform=ScaleTransform(x[2]),
),
[ell, p],
)

# deprecated `GaborKernel`
k2 = @test_deprecated GaborKernel(; ell=ell, p=p)
@test k2.ell ≈ ell atol = 1e-5
@test k2.p ≈ p atol = 1e-5
@test k2(v1, v2) ≈ k(v1, v2)

k3 = @test_deprecated GaborKernel()
@test k3.ell ≈ 1.0 atol = 1e-5
@test k3.p ≈ 1.0 atol = 1e-5
@test repr(k3) == "Gabor Kernel (ell = 1, p = 1)"

test_interface(k3, Vector{Float64})

test_ADs(x -> GaborKernel(; ell=x[1], p=x[2]), [ell, p]; ADs=[:Zygote])

Expand Down
4 changes: 4 additions & 0 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
@test k ∘ IdentityTransform() === k

kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt ∘ IdentityTransform() === kt
@test ktard ∘ IdentityTransform() === ktard
@test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2)
@test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2)
Expand Down