Skip to content

Commit 0065912

Browse files
committed
Add RationalKernel and rename GammaRationalQuadraticKernel
1 parent 0ff8761 commit 0065912

File tree

6 files changed

+124
-40
lines changed

6 files changed

+124
-40
lines changed

docs/src/kernels.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,12 @@ LinearKernel
8888
PolynomialKernel
8989
```
9090

91-
### Rational Quadratic Kernels
91+
### Rational Kernels
9292

9393
```@docs
94+
RationalKernel
9495
RationalQuadraticKernel
95-
GammaRationalQuadraticKernel
96+
GammaRationalKernel
9697
```
9798

9899
### Spectral Mixture Kernels

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ export ExponentiatedKernel
2626
export FBMKernel
2727
export MaternKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel
2828
export LinearKernel, PolynomialKernel
29-
export RationalQuadraticKernel, GammaRationalQuadraticKernel
29+
export RationalKernel, RationalQuadraticKernel
30+
export GammaRationalKernel, GammaRationalQuadraticKernel
3031
export GaborKernel, PiecewisePolynomialKernel
3132
export PeriodicKernel, NeuralNetworkKernel
3233
export KernelSum, KernelProduct

src/basekernels/rationalquad.jl

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,40 @@
1+
"""
2+
RationalKernel(; α::Real=2.0)
3+
4+
Rational kernel with shape parameter `α`.
5+
6+
# Definition
7+
8+
For inputs ``x, x' \\in \\mathbb{R}^d``, the rational kernel with shape parameter
9+
``\\alpha > 0`` is defined as
10+
```math
11+
k(x, x'; \\alpha) = \\bigg(1 + \\frac{\\|x - x'\\|_2}{\\alpha}\\bigg)^{-\\alpha}.
12+
```
13+
14+
The [`ExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\infty``.
15+
16+
See also: [`GammaRationalKernel`](@ref)
17+
"""
18+
struct RationalKernel{Tα<:Real} <: SimpleKernel
19+
α::Vector{Tα}
20+
function RationalKernel(; alpha::T=2.0, α::T=alpha) where {T}
21+
@check_args(RationalKernel, α, α > zero(T), "α > 0")
22+
return new{T}([α])
23+
end
24+
end
25+
26+
@functor RationalKernel
27+
28+
function kappa::RationalKernel, d::Real)
29+
return (one(d) + d / first.α))^(-first.α))
30+
end
31+
32+
metric(::RationalKernel) = Euclidean()
33+
34+
function Base.show(io::IO, κ::RationalKernel)
35+
return print(io, "Rational Kernel (α = $(first.α)))")
36+
end
37+
138
"""
239
RationalQuadraticKernel(; α::Real=2.0)
340
@@ -13,7 +50,7 @@ k(x, x'; \\alpha) = \\bigg(1 + \\frac{\\|x - x'\\|_2^2}{2\\alpha}\\bigg)^{-\\alp
1350
1451
The [`SqExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\infty``.
1552
16-
See also: [`GammaRationalQuadraticKernel`](@ref)
53+
See also: [`GammaRationalKernel`](@ref)
1754
"""
1855
struct RationalQuadraticKernel{Tα<:Real} <: SimpleKernel
1956
α::Vector{Tα}
@@ -36,44 +73,42 @@ function Base.show(io::IO, κ::RationalQuadraticKernel)
3673
end
3774

3875
"""
39-
GammaRationalQuadraticKernel(; α::Real=2.0, γ::Real=2.0)
76+
GammaRationalKernel(; α::Real=2.0, γ::Real=2.0)
4077
41-
γ-rational-quadratic kernel with shape parameters `α` and `γ`.
78+
γ-rational kernel with shape parameters `α` and `γ`.
4279
4380
# Definition
4481
45-
For inputs ``x, x' \\in \\mathbb{R}^d``, the γ-rational-quadratic kernel with shape
82+
For inputs ``x, x' \\in \\mathbb{R}^d``, the γ-rational kernel with shape
4683
parameters ``\\alpha > 0`` and ``\\gamma \\in (0, 2]`` is defined as
4784
```math
4885
k(x, x'; \\alpha, \\gamma) = \\bigg(1 + \\frac{\\|x - x'\\|_2^{\\gamma}}{\\alpha}\\bigg)^{-\\alpha}.
4986
```
5087
5188
The [`GammaExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\infty``.
5289
53-
See also: [`RationalQuadraticKernel`](@ref)
90+
See also: [`RationalKernel`](@ref), [`RationalQuadraticKernel`](@ref)
5491
"""
55-
struct GammaRationalQuadraticKernel{Tα<:Real,Tγ<:Real} <: SimpleKernel
92+
struct GammaRationalKernel{Tα<:Real,Tγ<:Real} <: SimpleKernel
5693
α::Vector{Tα}
5794
γ::Vector{Tγ}
58-
function GammaRationalQuadraticKernel(;
95+
function GammaRationalKernel(;
5996
alpha::Real=2.0, gamma::Real=2.0, α::Real=alpha, γ::Real=gamma
6097
)
61-
@check_args(GammaRationalQuadraticKernel, α, α > zero(α), "α > 0")
62-
@check_args(GammaRationalQuadraticKernel, γ, zero(γ) < γ 2, "γ ∈ (0, 2]")
98+
@check_args(GammaRationalKernel, α, α > zero(α), "α > 0")
99+
@check_args(GammaRationalKernel, γ, zero(γ) < γ 2, "γ ∈ (0, 2]")
63100
return new{typeof(α),typeof(γ)}([α], [γ])
64101
end
65102
end
66103

67-
@functor GammaRationalQuadraticKernel
104+
@functor GammaRationalKernel
68105

69-
function kappa::GammaRationalQuadraticKernel, d::Real)
106+
function kappa::GammaRationalKernel, d::Real)
70107
return (one(d) + d^first.γ) / first.α))^(-first.α))
71108
end
72109

73-
metric(::GammaRationalQuadraticKernel) = Euclidean()
110+
metric(::GammaRationalKernel) = Euclidean()
74111

75-
function Base.show(io::IO, κ::GammaRationalQuadraticKernel)
76-
return print(
77-
io, "Gamma Rational Quadratic Kernel (α = $(first.α)), γ = $(first.γ)))"
78-
)
112+
function Base.show(io::IO, κ::GammaRationalKernel)
113+
return print(io, "Gamma Rational Kernel (α = $(first.α)), γ = $(first.γ)))")
79114
end

src/deprecated.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
88
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
99
)
10+
11+
# TODO: remove in next breaking release
12+
const GammaRationalQuadraticKernel = GammaRationalKernel

test/basekernels/rationalquad.jl

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,31 @@
44
v1 = rand(rng, 3)
55
v2 = rand(rng, 3)
66

7+
@testset "RationalKernel" begin
8+
α = rand()
9+
k = RationalKernel(; α=α)
10+
11+
@testset "RationalKernel ≈ Exponential for large α" begin
12+
@test isapprox(
13+
RationalKernel(; α=1e9)(v1, v2),
14+
ExponentialKernel()(v1, v2);
15+
atol=1e-6,
16+
rtol=1e-6,
17+
)
18+
end
19+
20+
@test metric(RationalKernel()) == Euclidean()
21+
@test metric(RationalKernel(; α=α)) == Euclidean()
22+
@test repr(k) == "Rational Kernel (α = $(α))"
23+
24+
# Standardised tests.
25+
TestUtils.test_interface(k, Float64)
26+
test_ADs(x -> RationalKernel(; alpha=exp(x[1])), [α])
27+
test_params(k, ([α],))
28+
end
29+
730
@testset "RationalQuadraticKernel" begin
8-
α = 2.0
31+
α = rand()
932
k = RationalQuadraticKernel(; α=α)
1033

1134
@testset "RQ ≈ EQ for large α" begin
@@ -18,74 +41,91 @@
1841
end
1942

2043
@test metric(RationalQuadraticKernel()) == SqEuclidean()
21-
@test metric(RationalQuadraticKernel(; α=2.0)) == SqEuclidean()
44+
@test metric(RationalQuadraticKernel(; α=α)) == SqEuclidean()
2245
@test repr(k) == "Rational Quadratic Kernel (α = $(α))"
2346

2447
# Standardised tests.
2548
TestUtils.test_interface(k, Float64)
26-
test_ADs(x -> RationalQuadraticKernel(; alpha=x[1]), [α])
49+
test_ADs(x -> RationalQuadraticKernel(; alpha=exp(x[1])), [α])
2750
test_params(k, ([α],))
2851
end
2952

30-
@testset "GammaRationalQuadraticKernel" begin
31-
k = GammaRationalQuadraticKernel()
53+
@testset "GammaRationalKernel" begin
54+
k = GammaRationalKernel()
3255

33-
@test repr(k) == "Gamma Rational Quadratic Kernel (α = 2.0, γ = 2.0)"
56+
@test repr(k) == "Gamma Rational Kernel (α = 2.0, γ = 2.0)"
3457

35-
@testset "Default GammaRQ ≈ RQ with rescaled inputs" begin
58+
@testset "Default GammaRational ≈ RQ with rescaled inputs" begin
3659
@test isapprox(
37-
GammaRationalQuadraticKernel()(v1 ./ sqrt(2), v2 ./ sqrt(2)),
60+
GammaRationalKernel()(v1 ./ sqrt(2), v2 ./ sqrt(2)),
3861
RationalQuadraticKernel()(v1, v2),
3962
)
40-
a = 1.0 + rand()
63+
a = 1 + rand()
4164
@test isapprox(
42-
GammaRationalQuadraticKernel(; α=a)(v1 ./ sqrt(2), v2 ./ sqrt(2)),
65+
GammaRationalKernel(; α=a)(v1 ./ sqrt(2), v2 ./ sqrt(2)),
4366
RationalQuadraticKernel(; α=a)(v1, v2),
4467
)
4568
end
4669

47-
@testset "GammaRQ ≈ EQ for large α with rescaled inputs" begin
70+
@testset "Default GammaRational ≈ EQ for large α with rescaled inputs" begin
4871
v1 = randn(2)
4972
v2 = randn(2)
5073
@test isapprox(
51-
GammaRationalQuadraticKernel(; α=1e9)(v1 ./ sqrt(2), v2 ./ sqrt(2)),
74+
GammaRationalKernel(; α=1e9)(v1 ./ sqrt(2), v2 ./ sqrt(2)),
5275
SqExponentialKernel()(v1, v2);
5376
atol=1e-6,
5477
rtol=1e-6,
5578
)
5679
end
5780

58-
@testset "GammaRQ(γ=1) ≈ Exponential with rescaled inputs for large α" begin
81+
@testset "GammaRational(γ=1) ≈ Rational" begin
82+
@test isapprox(GammaRationalKernel(; γ=1.0)(v1, v2), RationalKernel()(v1, v2))
83+
a = 1 + rand()
84+
@test isapprox(
85+
GammaRationalKernel(; γ=1.0, α=a)(v1, v2), RationalKernel(; α=a)(v1, v2)
86+
)
87+
end
88+
89+
@testset "GammaRational(γ=1) ≈ Exponential with rescaled inputs for large α" begin
5990
v1 = randn(4)
6091
v2 = randn(4)
6192
@test isapprox(
62-
GammaRationalQuadraticKernel(; α=1e9, γ=1.0)(v1, v2),
93+
GammaRationalKernel(; α=1e9, γ=1.0)(v1, v2),
6394
ExponentialKernel()(v1, v2);
6495
atol=1e-6,
6596
rtol=1e-6,
6697
)
6798
end
6899

69-
@testset "GammaRQ ≈ GammaExponential for same γ and large α" begin
100+
@testset "GammaRational ≈ GammaExponential for same γ and large α" begin
70101
v1 = randn(3)
71102
v2 = randn(3)
72103
γ = rand() + 0.5
73104
@test isapprox(
74-
GammaRationalQuadraticKernel(; α=1e9, γ=γ)(v1, v2),
105+
GammaRationalKernel(; α=1e9, γ=γ)(v1, v2),
75106
GammaExponentialKernel(; γ=γ)(v1, v2);
76107
atol=1e-6,
77108
rtol=1e-6,
78109
)
79110
end
80111

81-
@test metric(GammaRationalQuadraticKernel()) == Euclidean()
82-
@test metric(GammaRationalQuadraticKernel(; γ=2.0)) == Euclidean()
83-
@test metric(GammaRationalQuadraticKernel(; γ=2.0, α=3.0)) == Euclidean()
112+
@test metric(GammaRationalKernel()) == Euclidean()
113+
@test metric(GammaRationalKernel(; γ=2.0)) == Euclidean()
114+
@test metric(GammaRationalKernel(; γ=2.0, α=3.0)) == Euclidean()
115+
116+
# Deprecations.
117+
a = rand()
118+
g = 2 * rand()
119+
@test GammaRationalQuadraticKernel()(v1, v2) == GammaRationalKernel()(v1, v2)
120+
@test GammaRationalQuadraticKernel(; γ=g)(v1, v2) ==
121+
GammaRationalKernel(; γ=g)(v1, v2)
122+
@test GammaRationalQuadraticKernel(; γ=g, α=a)(v1, v2) ==
123+
GammaRationalKernel(; γ=g, α=a)(v1, v2)
84124

85125
# Standardised tests.
86126
TestUtils.test_interface(k, Float64)
87127
a = 1.0 + rand()
88-
test_ADs(x -> GammaRationalQuadraticKernel(; α=x[1], γ=x[2]), [a, 1 + 0.5 * rand()])
89-
test_params(GammaRationalQuadraticKernel(; α=a, γ=x), ([a], [x]))
128+
test_ADs(x -> GammaRationalKernel(; α=x[1], γ=x[2]), [a, 1 + 0.5 * rand()])
129+
test_params(GammaRationalKernel(; α=a, γ=x), ([a], [x]))
90130
end
91131
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ using KernelFunctions.TestUtils: test_interface
5151
@info "Packages Loaded"
5252

5353
include("test_utils.jl")
54+
include(joinpath("basekernels", "exponential.jl"))
55+
include(joinpath("basekernels", "rationalquad.jl"))
56+
include(joinpath("basekernels", "wiener.jl"))
57+
exit(0)
5458

5559
@testset "KernelFunctions" begin
5660
include("utils.jl")

0 commit comments

Comments
 (0)