Skip to content

Commit df3819a

Browse files
authored
Add Spectral Mixture Kernel (#80)
* Add Spectral Mixture Kernel * Define Kappa * Move to basekernels folder * Remove kappa * Update SM kernel * Add StretchTransform * Reduce SpectralMixtureKernel to a function, add tests and modify docstring * Add SpectralMixtureProductKernel * Fix SpectralMixtureProductKernel * Fix tests * Change name to lowercase and fix bugs in doc * Remove StretchTransform * Fix typo in docstring * Remove unnecessary line from test * Fix typo in docstring * Improve docstring * Change errors thrown to DimensionMismatch * Add SqExponentialKernel as default kernel * Fix line length * Update docstrings
1 parent 2e92ecd commit df3819a

File tree

4 files changed

+136
-0
lines changed

4 files changed

+136
-0
lines changed

src/KernelFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LinearTransfo
2828

2929
export NystromFact, nystrom
3030

31+
export spectral_mixture_kernel, spectral_mixture_product_kernel
32+
33+
3134
using Compat
3235
using Requires
3336
using Distances, LinearAlgebra

src/basekernels/sm.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
spectral_mixture_kernel(
3+
h::Kernel=SqExponentialKernel(),
4+
αs::AbstractVector{<:Real},
5+
γs::AbstractMatrix{<:Real},
6+
ωs::AbstractMatrix{<:Real},
7+
)
8+
9+
where αs are the weights of dimension (A, ), γs is the covariance matrix of
10+
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
11+
Here, D is input dimension and A is the number of spectral components.
12+
13+
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
14+
15+
Generalised Spectral Mixture kernel function. This family of functions is dense
16+
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
17+
18+
```math
19+
κ(x, y) = αs' (h(-(γs' * t)^2) .* cos(π * ωs' * t), t = x - y
20+
```
21+
22+
# References:
23+
[1] Generalized Spectral Kernels, by Yves-Laurent Kom Samo and Stephen J. Roberts
24+
[2] SM: Gaussian Process Kernels for Pattern Discovery and Extrapolation,
25+
ICML, 2013, by Andrew Gordon Wilson and Ryan Prescott Adams,
26+
[3] Covariance kernels for fast automatic pattern discovery and extrapolation
27+
with Gaussian processes, Andrew Gordon Wilson, PhD Thesis, January 2014.
28+
http://www.cs.cmu.edu/~andrewgw/andrewgwthesis.pdf
29+
[4] http://www.cs.cmu.edu/~andrewgw/pattern/.
30+
31+
"""
32+
function spectral_mixture_kernel(
33+
h::Kernel,
34+
αs::AbstractVector{<:Real},
35+
γs::AbstractMatrix{<:Real},
36+
ωs::AbstractMatrix{<:Real},
37+
)
38+
if !(size(αs, 1) == size(γs, 2) == size(ωs, 2))
39+
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
40+
end
41+
if size(γs) != size(ωs)
42+
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
43+
end
44+
45+
return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
46+
a = TransformedKernel(h, LinearTransform'))
47+
b = TransformedKernel(CosineKernel(), LinearTransform'))
48+
return α * a * b
49+
end
50+
end
51+
52+
function spectral_mixture_kernel(
53+
αs::AbstractVector{<:Real},
54+
γs::AbstractMatrix{<:Real},
55+
ωs::AbstractMatrix{<:Real}
56+
)
57+
spectral_mixture_kernel(SqExponentialKernel(), αs, γs, ωs)
58+
end
59+
60+
"""
61+
spectral_mixture_product_kernel(
62+
h::Kernel=SqExponentialKernel(),
63+
αs::AbstractMatrix{<:Real},
64+
γs::AbstractMatrix{<:Real},
65+
ωs::AbstractMatrix{<:Real},
66+
)
67+
68+
where αs are the weights of dimension (D, A), γs is the covariance matrix of
69+
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
70+
Here, D is input dimension and A is the number of spectral components.
71+
72+
Spectral Mixture Product Kernel. With enough components A, the SMP kernel
73+
can model any product kernel to arbitrary precision, and is flexible even
74+
with a small number of components [1]
75+
76+
77+
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
78+
79+
```math
80+
κ(x, y) = Πᵢ₌₁ᴷ Σ(αsᵢᵀ .* (h(-(γsᵢᵀ * tᵢ)²) .* cos(ωsᵢᵀ * tᵢ))), tᵢ = xᵢ - yᵢ
81+
```
82+
83+
# References:
84+
[1] GPatt: Fast Multidimensional Pattern Extrapolation with GPs,
85+
arXiv 1310.5288, 2013, by Andrew Gordon Wilson, Elad Gilboa,
86+
Arye Nehorai and John P. Cunningham
87+
"""
88+
function spectral_mixture_product_kernel(
89+
h::Kernel,
90+
αs::AbstractMatrix{<:Real},
91+
γs::AbstractMatrix{<:Real},
92+
ωs::AbstractMatrix{<:Real},
93+
)
94+
if !(size(αs) == size(γs) == size(ωs))
95+
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
96+
end
97+
return TensorProduct(spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :))
98+
for (α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs)))
99+
end
100+
101+
function spectral_mixture_product_kernel(
102+
αs::AbstractMatrix{<:Real},
103+
γs::AbstractMatrix{<:Real},
104+
ωs::AbstractMatrix{<:Real}
105+
)
106+
spectral_mixture_product_kernel(SqExponentialKernel(), αs, γs, ωs)
107+
end
108+

test/basekernels/sm.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "sm" begin
2+
v1 = rand(5)
3+
v2 = rand(5)
4+
5+
αs₁ = rand(3)
6+
αs₂ = rand(5, 3)
7+
γs = rand(5, 3)
8+
ωs = rand(5, 3)
9+
10+
k1 = spectral_mixture_kernel(αs₁, γs, ωs)
11+
k2 = spectral_mixture_product_kernel(αs₂, γs, ωs)
12+
13+
t = v1 - v2
14+
15+
@test k1(v1, v2) sum(αs₁ .* exp.(-(t' * γs)'.^2) .*
16+
cospi.((t' * ωs)')) atol=1e-5
17+
18+
@test k2(v1, v2) prod(sum(αs₂[i,:]' .* exp.(-(γs[i,:]' * t[i]).^2) .*
19+
cospi.(ωs[i,:]' * t[i])) for i in 1:length(t)) atol=1e-5
20+
21+
@test_throws DimensionMismatch spectral_mixture_kernel(rand(5) ,rand(4,3), rand(4,3))
22+
@test_throws DimensionMismatch spectral_mixture_kernel(rand(3) ,rand(4,3), rand(5,3))
23+
@test_throws DimensionMismatch spectral_mixture_product_kernel(rand(5,3) ,rand(4,3), rand(5,3))
24+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ using KernelFunctions: metric, kappa
7676
include(joinpath("basekernels", "polynomial.jl"))
7777
include(joinpath("basekernels", "piecewisepolynomial.jl"))
7878
include(joinpath("basekernels", "rationalquad.jl"))
79+
include(joinpath("basekernels", "sm.jl"))
7980
include(joinpath("basekernels", "wiener.jl"))
8081
end
8182

0 commit comments

Comments
 (0)