Skip to content

Commit a99a3e7

Browse files
authored
Add gaborkernel and simplify kernels with IdentityTransforms (#285)
1 parent 3d8b245 commit a99a3e7

File tree

7 files changed

+92
-22
lines changed

7 files changed

+92
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.9.6"
3+
version = "0.9.7"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/kernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ FBMKernel
5050
### Gabor Kernel
5151

5252
```@docs
53+
gaborkernel
5354
GaborKernel
5455
```
5556

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export Transform,
3131

3232
export NystromFact, nystrom
3333

34+
export gaborkernel
3435
export spectral_mixture_kernel, spectral_mixture_product_kernel
3536

3637
export ColVecs, RowVecs
@@ -73,6 +74,7 @@ include(joinpath("transform", "functiontransform.jl"))
7374
include(joinpath("transform", "selecttransform.jl"))
7475
include(joinpath("transform", "chaintransform.jl"))
7576
include(joinpath("transform", "periodic_transform.jl"))
77+
include(joinpath("kernels", "transformedkernel.jl"))
7678

7779
include(joinpath("basekernels", "constant.jl"))
7880
include(joinpath("basekernels", "cosine.jl"))
@@ -89,7 +91,6 @@ include(joinpath("basekernels", "rational.jl"))
8991
include(joinpath("basekernels", "sm.jl"))
9092
include(joinpath("basekernels", "wiener.jl"))
9193

92-
include(joinpath("kernels", "transformedkernel.jl"))
9394
include(joinpath("kernels", "scaledkernel.jl"))
9495
include(joinpath("kernels", "normalizedkernel.jl"))
9596
include(joinpath("matrix", "kernelmatrix.jl"))

src/basekernels/gabor.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
1+
"""
2+
gaborkernel(;
3+
sqexponential_transform=IdentityTransform(), cosine_tranform=IdentityTransform()
4+
)
5+
6+
Construct a Gabor kernel with transformations `sqexponential_transform` and
7+
`cosine_transform` of the inputs of the underlying squared exponential and cosine kernel,
8+
respectively.
9+
10+
# Definition
11+
12+
For inputs ``x, x' \\in \\mathbb{R}^d``, the Gabor kernel with transformations ``f``
13+
and ``g`` of the inputs to the squared exponential and cosine kernel, respectively,
14+
is defined as
15+
```math
16+
k(x, x'; f, g) = \\exp\\bigg(- \\frac{\\| f(x) - f(x')\\|_2^2}{2}\\bigg)
17+
\\cos\\big(\\pi \\|g(x) - g(x')\\|_2 \\big).
18+
```
19+
"""
20+
function gaborkernel(;
21+
sqexponential_transform=IdentityTransform(), cosine_transform=IdentityTransform()
22+
)
23+
return (SqExponentialKernel() sqexponential_transform) *
24+
(CosineKernel() cosine_transform)
25+
end
26+
27+
# everything below will be removed
128
"""
229
GaborKernel(; ell::Real=1.0, p::Real=1.0)
330
@@ -11,11 +38,20 @@ and period ``p_i > 0`` is defined as
1138
k(x, x'; l, p) = \\exp\\bigg(- \\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{2l_i^2}\\bigg)
1239
\\cos\\bigg(\\pi \\bigg(\\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{p_i^2} \\bigg)^{1/2}\\bigg).
1340
```
41+
42+
!!! note
43+
`GaborKernel` is deprecated and will be removed. Gabor kernels should be
44+
constructed with [`gaborkernel`](@ref) instead.
1445
"""
1546
struct GaborKernel{K<:Kernel} <: Kernel
1647
kernel::K
1748

1849
function GaborKernel(; ell=nothing, p=nothing)
50+
Base.depwarn(
51+
"`GaborKernel` is deprecated and will be removed. Gabor kernels should be " *
52+
"constructed with `gaborkernel` instead.",
53+
:GaborKernel,
54+
)
1955
ell_transform = _lengthscale_transform(ell)
2056
p_transform = _lengthscale_transform(p)
2157
k = (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
@@ -31,19 +67,18 @@ _lengthscale_transform(::Nothing) = IdentityTransform()
3167
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
3268
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))
3369

34-
_lengthscale(::IdentityTransform) = 1
70+
_lengthscale(x) = 1
71+
_lengthscale(k::TransformedKernel) = _lengthscale(k.transform)
3572
_lengthscale(t::ScaleTransform) = inv(first(t.s))
3673
_lengthscale(t::ARDTransform) = map(inv, t.v)
3774

3875
function Base.getproperty(k::GaborKernel, v::Symbol)
3976
if v == :kernel
4077
return getfield(k, v)
4178
elseif v == :ell
42-
ell_transform = k.kernel.kernels[1].transform
43-
return _lengthscale(ell_transform)
79+
return _lengthscale(k.kernel.kernels[1])
4480
elseif v == :p
45-
p_transform = k.kernel.kernels[2].transform
46-
return _lengthscale(p_transform)
81+
return _lengthscale(k.kernel.kernels[2])
4782
else
4883
error("Invalid Property")
4984
end

src/kernels/transformedkernel.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ See also: [`TransformedKernel`](@ref)
6868
Base.:(k::Kernel, t::Transform) = TransformedKernel(k, t)
6969
Base.:(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform t)
7070

71+
# Simplify kernels with identity transformation of the inputs
72+
Base.:(k::Kernel, ::IdentityTransform) = k
73+
Base.:(k::TransformedKernel, ::IdentityTransform) = k
74+
7175
Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)
7276

7377
function printshifted(io::IO, κ::TransformedKernel, shift::Int)

test/basekernels/gabor.jl

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,50 @@
11
@testset "Gabor" begin
22
v1 = rand(3)
33
v2 = rand(3)
4-
ell = abs(rand())
5-
p = abs(rand())
6-
k = GaborKernel(; ell=ell, p=p)
7-
@test k.ell ell atol = 1e-5
8-
@test k.p p atol = 1e-5
4+
ell = rand()
5+
p = rand()
6+
k = gaborkernel(;
7+
sqexponential_transform=ScaleTransform(inv(ell)),
8+
cosine_transform=ScaleTransform(inv(p)),
9+
)
10+
@test k isa KernelProduct{
11+
<:Tuple{
12+
TransformedKernel{SqExponentialKernel,<:ScaleTransform},
13+
TransformedKernel{CosineKernel,<:ScaleTransform},
14+
},
15+
}
16+
@test k.kernels[1].transform.s[1] == inv(ell)
17+
@test k.kernels[2].transform.s[1] == inv(p)
918

10-
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
11-
@test k(v1, v2) k_manual atol = 1e-5
19+
k_manual = exp(-sqeuclidean(v1, v2) / (2 * ell^2)) * cospi(euclidean(v1, v2) / p)
20+
@test k_manual k(v1, v2) atol = 1e-5
1221

13-
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
14-
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
15-
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5
22+
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / ell))(v1, v2)
23+
rhs_manual = (CosineKernel() ScaleTransform(1 / p))(v1, v2)
24+
@test lhs_manual * rhs_manual k(v1, v2) atol = 1e-5
1625

17-
k = GaborKernel()
18-
@test k.ell 1.0 atol = 1e-5
19-
@test k.p 1.0 atol = 1e-5
20-
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"
26+
@test gaborkernel() isa KernelProduct{Tuple{SqExponentialKernel,CosineKernel}}
2127

22-
test_interface(k, Vector{Float64})
28+
test_ADs(
29+
x -> gaborkernel(;
30+
sqexponential_transform=ScaleTransform(x[1]),
31+
cosine_transform=ScaleTransform(x[2]),
32+
),
33+
[ell, p],
34+
)
35+
36+
# deprecated `GaborKernel`
37+
k2 = @test_deprecated GaborKernel(; ell=ell, p=p)
38+
@test k2.ell ell atol = 1e-5
39+
@test k2.p p atol = 1e-5
40+
@test k2(v1, v2) k(v1, v2)
41+
42+
k3 = @test_deprecated GaborKernel()
43+
@test k3.ell 1.0 atol = 1e-5
44+
@test k3.p 1.0 atol = 1e-5
45+
@test repr(k3) == "Gabor Kernel (ell = 1, p = 1)"
46+
47+
test_interface(k3, Vector{Float64})
2348

2449
test_ADs(x -> GaborKernel(; ell=x[1], p=x[2]), [ell, p]; ADs=[:Zygote])
2550

test/kernels/transformedkernel.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
v = rand(rng, 3)
99
P = rand(rng, 3, 2)
1010
k = SqExponentialKernel()
11+
@test k IdentityTransform() === k
12+
1113
kt = TransformedKernel(k, ScaleTransform(s))
1214
ktard = TransformedKernel(k, ARDTransform(v))
15+
@test kt IdentityTransform() === kt
16+
@test ktard IdentityTransform() === ktard
1317
@test kt(v1, v2) == (k ScaleTransform(s))(v1, v2)
1418
@test kt(v1, v2) k(s * v1, s * v2) atol = 1e-5
1519
@test ktard(v1, v2) == (k ARDTransform(v))(v1, v2)

0 commit comments

Comments
 (0)