Skip to content

Commit a827302

Browse files
authored
Add metric field (#286)
1 parent df67ab7 commit a827302

24 files changed

+295
-322
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.9.7"
3+
version = "0.10.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/kernels.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ FBMKernel
5151

5252
```@docs
5353
gaborkernel
54-
GaborKernel
5554
```
5655

5756
### Matérn Kernels

src/KernelFunctions.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag!
4-
export transform
54
export duplicate, set! # Helpers
65

76
export Kernel, MOKernel
@@ -14,7 +13,7 @@ export FBMKernel
1413
export MaternKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel
1514
export LinearKernel, PolynomialKernel
1615
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
17-
export GaborKernel, PiecewisePolynomialKernel
16+
export PiecewisePolynomialKernel
1817
export PeriodicKernel, NeuralNetworkKernel
1918
export KernelSum, KernelProduct, KernelTensorProduct
2019
export TransformedKernel, ScaledKernel, NormalizedKernel
@@ -112,8 +111,6 @@ include("zygoterules.jl")
112111

113112
include("test_utils.jl")
114113

115-
include("deprecations.jl")
116-
117114
function __init__()
118115
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
119116
include(joinpath("matrix", "kernelkroneckermat.jl"))

src/basekernels/cosine.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
"""
2-
CosineKernel()
2+
CosineKernel(; metric=Euclidean())
33
4-
Cosine kernel.
4+
Cosine kernel with respect to the `metric`.
55
66
# Definition
77
8-
For inputs ``x, x' \\in \\mathbb{R}^d``, the cosine kernel is defined as
8+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the cosine kernel is defined as
99
```math
10-
k(x, x') = \\cos(\\pi \\|x-x'\\|_2).
10+
k(x, x') = \\cos(\\pi d(x, x')).
1111
```
12+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
1213
"""
13-
struct CosineKernel <: SimpleKernel end
14+
struct CosineKernel{M} <: SimpleKernel
15+
metric::M
16+
17+
function CosineKernel(; metric=Euclidean())
18+
return new{typeof(metric)}(metric)
19+
end
20+
end
1421

1522
kappa(::CosineKernel, d::Real) = cospi(d)
1623

17-
metric(::CosineKernel) = Euclidean()
24+
metric(k::CosineKernel) = k.metric
1825

19-
Base.show(io::IO, ::CosineKernel) = print(io, "Cosine Kernel")
26+
Base.show(io::IO, k::CosineKernel) = print(io, "Cosine Kernel (metric = ", k.metric, ")")

src/basekernels/exponential.jl

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
11
"""
2-
SqExponentialKernel()
2+
SqExponentialKernel(; metric=Euclidean())
33
4-
Squared exponential kernel.
4+
Squared exponential kernel with respect to the `metric`.
55
66
# Definition
77
8-
For inputs ``x, x' \\in \\mathbb{R}^d``, the squared exponential kernel is defined as
8+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the squared exponential kernel is
9+
defined as
910
```math
10-
k(x, x') = \\exp\\bigg(- \\frac{\\|x - x'\\|_2^2}{2}\\bigg).
11+
k(x, x') = \\exp\\bigg(- \\frac{d(x, x')^2}{2}\\bigg).
1112
```
13+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
1214
1315
See also: [`GammaExponentialKernel`](@ref)
1416
"""
15-
struct SqExponentialKernel <: SimpleKernel end
17+
struct SqExponentialKernel{M} <: SimpleKernel
18+
metric::M
1619

17-
kappa(::SqExponentialKernel, d²::Real) = exp(-/ 2)
20+
function SqExponentialKernel(; metric=Euclidean())
21+
return new{typeof(metric)}(metric)
22+
end
23+
end
24+
25+
kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2)
26+
kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-/ 2)
1827

19-
metric(::SqExponentialKernel) = SqEuclidean()
28+
metric(k::SqExponentialKernel) = k.metric
29+
metric(::SqExponentialKernel{<:Euclidean}) = SqEuclidean()
2030

2131
iskroncompatible(::SqExponentialKernel) = true
2232

23-
Base.show(io::IO, ::SqExponentialKernel) = print(io, "Squared Exponential Kernel")
33+
function Base.show(io::IO, k::SqExponentialKernel)
34+
return print(io, "Squared Exponential Kernel (metric = ", k.metric, ")")
35+
end
2436

2537
## Aliases ##
2638

@@ -46,28 +58,37 @@ Alias of [`SqExponentialKernel`](@ref).
4658
const SEKernel = SqExponentialKernel
4759

4860
"""
49-
ExponentialKernel()
61+
ExponentialKernel(; metric=Euclidean())
5062
51-
Exponential kernel.
63+
Exponential kernel with respect to the `metric`.
5264
5365
# Definition
5466
55-
For inputs ``x, x' \\in \\mathbb{R}^d``, the exponential kernel is defined as
67+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the exponential kernel is defined as
5668
```math
57-
k(x, x') = \\exp\\big(- \\|x - x'\\|_2\\big).
69+
k(x, x') = \\exp\\big(- d(x, x')\\big).
5870
```
71+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
5972
6073
See also: [`GammaExponentialKernel`](@ref)
6174
"""
62-
struct ExponentialKernel <: SimpleKernel end
75+
struct ExponentialKernel{M} <: SimpleKernel
76+
metric::M
77+
78+
function ExponentialKernel(; metric=Euclidean())
79+
return new{typeof(metric)}(metric)
80+
end
81+
end
6382

6483
kappa(::ExponentialKernel, d::Real) = exp(-d)
6584

66-
metric(::ExponentialKernel) = Euclidean()
85+
metric(k::ExponentialKernel) = k.metric
6786

6887
iskroncompatible(::ExponentialKernel) = true
6988

70-
Base.show(io::IO, ::ExponentialKernel) = print(io, "Exponential Kernel")
89+
function Base.show(io::IO, k::ExponentialKernel)
90+
return print(io, "Exponential Kernel (metric = ", k.metric, ")")
91+
end
7192

7293
## Aliases ##
7394

@@ -86,53 +107,44 @@ Alias of [`ExponentialKernel`](@ref).
86107
const Matern12Kernel = ExponentialKernel
87108

88109
"""
89-
GammaExponentialKernel(; γ::Real=2.0)
110+
GammaExponentialKernel(; γ::Real=1.0, metric=Euclidean())
90111
91-
γ-exponential kernel with parameter `γ`.
112+
γ-exponential kernel with respect to the `metric` and with parameter `γ`.
92113
93114
# Definition
94115
95-
For inputs ``x, x' \\in \\mathbb{R}^d``, the γ-exponential kernel[^RW] with parameter
96-
``\\gamma \\in (0, 2]`` is defined as
116+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the γ-exponential kernel[^RW] with
117+
parameter ``\\gamma \\in (0, 2]``
118+
is defined as
97119
```math
98-
k(x, x'; \\gamma) = \\exp\\big(- \\|x - x'\\|_2^{\\gamma}\\big).
120+
k(x, x'; \\gamma) = \\exp\\big(- d(x, x')^{\\gamma}\\big).
99121
```
100-
101-
!!! warning
102-
The default value of parameter `γ` will be changed to `1.0` in the next breaking release
103-
of KernelFunctions.
122+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
104123
105124
See also: [`ExponentialKernel`](@ref), [`SqExponentialKernel`](@ref)
106125
107126
[^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning.
108127
"""
109-
struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
128+
struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel
110129
γ::Vector{Tγ}
111-
# function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma)
112-
function GammaExponentialKernel(; gamma=nothing, γ=gamma)
113-
γ2 = if γ === nothing
114-
Base.depwarn(
115-
"the default value of parameter `γ` of the `GammaExponentialKernel` will " *
116-
"be changed to `1.0` in the next breaking release of KernelFunctions",
117-
:GammaExponentialKernel,
118-
)
119-
2.0
120-
else
121-
γ
122-
end
123-
@check_args(GammaExponentialKernel, γ2, zero(γ2) < γ2 2, "γ ∈ (0, 2]")
124-
return new{typeof(γ2)}([γ2])
130+
metric::M
131+
132+
function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean())
133+
@check_args(GammaExponentialKernel, γ, zero(γ) < γ 2, "γ ∈ (0, 2]")
134+
return new{typeof(γ),typeof(metric)}([γ], metric)
125135
end
126136
end
127137

128138
@functor GammaExponentialKernel
129139

130140
kappa::GammaExponentialKernel, d::Real) = exp(-d^first.γ))
131141

132-
metric(::GammaExponentialKernel) = Euclidean()
142+
metric(k::GammaExponentialKernel) = k.metric
133143

134144
iskroncompatible(::GammaExponentialKernel) = true
135145

136146
function Base.show(io::IO, κ::GammaExponentialKernel)
137-
return print(io, "Gamma Exponential Kernel (γ = ", first.γ), ")")
147+
return print(
148+
io, "Gamma Exponential Kernel (γ = ", first.γ), ", metric = ", κ.metric, ")"
149+
)
138150
end

src/basekernels/gabor.jl

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -23,79 +23,3 @@ function gaborkernel(;
2323
return (SqExponentialKernel() sqexponential_transform) *
2424
(CosineKernel() cosine_transform)
2525
end
26-
27-
# everything below will be removed
28-
"""
29-
GaborKernel(; ell::Real=1.0, p::Real=1.0)
30-
31-
Gabor kernel with lengthscale `ell` and period `p`.
32-
33-
# Definition
34-
35-
For inputs ``x, x' \\in \\mathbb{R}^d``, the Gabor kernel with lengthscale ``l_i > 0``
36-
and period ``p_i > 0`` is defined as
37-
```math
38-
k(x, x'; l, p) = \\exp\\bigg(- \\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{2l_i^2}\\bigg)
39-
\\cos\\bigg(\\pi \\bigg(\\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{p_i^2} \\bigg)^{1/2}\\bigg).
40-
```
41-
42-
!!! note
43-
`GaborKernel` is deprecated and will be removed. Gabor kernels should be
44-
constructed with [`gaborkernel`](@ref) instead.
45-
"""
46-
struct GaborKernel{K<:Kernel} <: Kernel
47-
kernel::K
48-
49-
function GaborKernel(; ell=nothing, p=nothing)
50-
Base.depwarn(
51-
"`GaborKernel` is deprecated and will be removed. Gabor kernels should be " *
52-
"constructed with `gaborkernel` instead.",
53-
:GaborKernel,
54-
)
55-
ell_transform = _lengthscale_transform(ell)
56-
p_transform = _lengthscale_transform(p)
57-
k = (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
58-
return new{typeof(k)}(k)
59-
end
60-
end
61-
62-
@functor GaborKernel
63-
64-
::GaborKernel)(x, y) = κ.kernel(x, y)
65-
66-
_lengthscale_transform(::Nothing) = IdentityTransform()
67-
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
68-
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))
69-
70-
_lengthscale(x) = 1
71-
_lengthscale(k::TransformedKernel) = _lengthscale(k.transform)
72-
_lengthscale(t::ScaleTransform) = inv(first(t.s))
73-
_lengthscale(t::ARDTransform) = map(inv, t.v)
74-
75-
function Base.getproperty(k::GaborKernel, v::Symbol)
76-
if v == :kernel
77-
return getfield(k, v)
78-
elseif v == :ell
79-
return _lengthscale(k.kernel.kernels[1])
80-
elseif v == :p
81-
return _lengthscale(k.kernel.kernels[2])
82-
else
83-
error("Invalid Property")
84-
end
85-
end
86-
87-
function Base.show(io::IO, κ::GaborKernel)
88-
return print(io, "Gabor Kernel (ell = ", κ.ell, ", p = ", κ.p, ")")
89-
end
90-
91-
kernelmatrix::GaborKernel, x::AbstractVector) = kernelmatrix.kernel, x)
92-
93-
function kernelmatrix::GaborKernel, x::AbstractVector, y::AbstractVector)
94-
return kernelmatrix.kernel, x, y)
95-
end
96-
97-
kernelmatrix_diag::GaborKernel, x::AbstractVector) = kernelmatrix_diag.kernel, x)
98-
99-
function kernelmatrix_diag::GaborKernel, x::AbstractVector, y::AbstractVector)
100-
return kernelmatrix_diag.kernel, x, y)
101-
end

0 commit comments

Comments
 (0)