Skip to content

Commit 8c99314

Browse files
authored
Add Latent-factor multi-output kernel (#143)
* Add Latent-factor multi-output kernel * Add docstring * Address code review * Fix test * Modify show function and contructor input check * Address code review * Fix docstring formatting * Fix typo and add tests for show * Use 'Tuple{Any, Int}' * Fix kernel, update docs, add MOKernel type * Export MOKernel * Make kernelmatrix more efficient * avoid unpacking generators * Add Zygote AD test * Adress code review * Patch bump
1 parent 3031ef7 commit 8c99314

File tree

10 files changed

+126
-7
lines changed

10 files changed

+126
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.7.1"
3+
version = "0.7.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/KernelFunctions.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!
88
export transform
99
export duplicate, set! # Helpers
1010

11-
export Kernel
11+
export Kernel, MOKernel
1212
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel
1313
export CosineKernel
1414
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
@@ -32,7 +32,7 @@ export NystromFact, nystrom
3232
export spectral_mixture_kernel, spectral_mixture_product_kernel
3333

3434
export MOInput
35-
export IndependentMOKernel
35+
export IndependentMOKernel, LatentFactorMOKernel
3636

3737
using Compat
3838
using Requires
@@ -72,8 +72,10 @@ include("kernels/tensorproduct.jl")
7272
include("approximations/nystrom.jl")
7373
include("generic.jl")
7474

75+
include("mokernels/mokernel.jl")
7576
include("mokernels/moinput.jl")
7677
include("mokernels/independent.jl")
78+
include("mokernels/slfm.jl")
7779

7880
include("zygote_adjoints.jl")
7981

src/basekernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ metric(κ::PeriodicKernel) = Sinus(κ.r)
2626

2727
kappa::PeriodicKernel, d::Real) = exp(- 0.5d)
2828

29-
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = ", length.r), ")")
29+
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel (length(r) = ", length.r), ")")

src/mokernels/independent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
A Multi-Output kernel which assumes each output is independent of the other.
55
"""
6-
struct IndependentMOKernel{Tkernel<:Kernel} <: Kernel
6+
struct IndependentMOKernel{Tkernel<:Kernel} <: MOKernel
77
kernel::Tkernel
88
end
99

src/mokernels/mokernel.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""
2+
MOKernel
3+
4+
An abstract type for multi-output kernels.
5+
"""
6+
abstract type MOKernel <: Kernel end

src/mokernels/slfm.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
@doc raw"""
2+
LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix)
3+
4+
The kernel associated with the Semiparametric Latent Factor Model, introduced by
5+
Seeger, Teh and Jordan (2005).
6+
7+
``k((x, p_x), (y, p_y)) = \Sum^{Q}_{q=1} A_{p_xq}g_q(x, y)A_{p_yq} + e((x, p_x), (y, p_y))``
8+
9+
# Arguments
10+
- `g`: a collection of kernels, one for each latent process
11+
- `e`: a [`MOKernel`](@ref) - multi-output kernel
12+
- `A::AbstractMatrix`: a matrix of weights for the kernels of size `(out_dim, length(g))`
13+
14+
15+
# Reference:
16+
- [Seeger, Teh, and Jordan (2005)](https://infoscience.epfl.ch/record/161465/files/slfm-long.pdf)
17+
18+
"""
19+
struct LatentFactorMOKernel{Tg, Te <: MOKernel, TA <: AbstractMatrix} <: MOKernel
20+
g::Tg
21+
e::Te
22+
A::TA
23+
function LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix)
24+
all(gi isa Kernel for gi in g) || error("`g` should be an collection of kernels")
25+
length(g) == size(A, 2) ||
26+
error("Size of `A` not compatible with the given array of kernels `g`")
27+
return new{typeof(g), typeof(e), typeof(A)}(g, e, A)
28+
end
29+
end
30+
31+
function::LatentFactorMOKernel)((x, px)::Tuple{Any, Int}, (y, py)::Tuple{Any, Int})
32+
cov_f = sum.A[px, q] * κ.g[q](x, y) * κ.A[py, q] for q in 1:length.g))
33+
return cov_f + κ.e((x, px), (y, py))
34+
end
35+
36+
function kernelmatrix(k::LatentFactorMOKernel, x::MOInput, y::MOInput)
37+
x.out_dim == y.out_dim || error("`x` and `y` should have the same output dimension")
38+
x.out_dim == size(k.A, 1) ||
39+
error("Kernel not compatible with the given multi-output inputs")
40+
41+
# Weights matrix ((out_dim x out_dim) x length(k.g))
42+
W = [col * col' for col in eachcol(k.A)]
43+
44+
# Latent kernel matrix ((N x N) x length(k.g))
45+
H = [gi.(x.x, permutedims(y.x)) for gi in k.g]
46+
47+
# Weighted latent kernel matrix ((N*out_dim) x (N*out_dim))
48+
W_H = sum(kron(Wi, Hi) for (Wi, Hi) in zip(W, H))
49+
50+
return W_H .+ kernelmatrix(k.e, x, y)
51+
end
52+
53+
function Base.show(io::IO, k::LatentFactorMOKernel)
54+
print(io, "Semi-parametric Latent Factor Multi-Output Kernel")
55+
end
56+
57+
function Base.show(io::IO, ::MIME"text/plain", k::LatentFactorMOKernel)
58+
print(io, "Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: ")
59+
join(io, k.g, "\n\t\t")
60+
print(io, "\n\teᵢ: ")
61+
join(io, k.e, "\n\t\t")
62+
end

test/basekernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
@test k(v1, v2) exp(-0.5 * sum(abs2, sinpi.(v1 - v2) ./ r))
77
@test k(v1, v2) == k(v2, v1)
88
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
9-
@test repr(k) == "Periodic Kernel, length(r) = $(length(r)))"
9+
@test repr(k) == "Periodic Kernel (length(r) = $(length(r)))"
1010
# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
1111
@test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff"
1212
test_params(k, (r,))

test/mokernels/independent.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
k = IndependentMOKernel(GaussianKernel())
66
@test k isa IndependentMOKernel
7+
@test k isa MOKernel
78
@test k isa Kernel
8-
@test k.kernel isa KernelFunctions.Kernel
9+
@test k.kernel isa Kernel
910
@test k(x[2], y[2]) isa Real
1011

1112
@test kernelmatrix(k, x, y) == kernelmatrix(k, collect(x), collect(y))

test/mokernels/slfm.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
@testset "slfm" begin
2+
rng = MersenneTwister(123)
3+
FDM = FiniteDifferences.central_fdm(5, 1)
4+
N = 10
5+
in_dim = 5
6+
out_dim = 4
7+
x1 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
8+
x2 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
9+
10+
k = LatentFactorMOKernel(
11+
[MaternKernel(), SqExponentialKernel(), FBMKernel()],
12+
IndependentMOKernel(GaussianKernel()),
13+
rand(rng, out_dim, 3),
14+
)
15+
@test k isa LatentFactorMOKernel
16+
@test k isa MOKernel
17+
@test k isa Kernel
18+
@test k(x1[1], x2[1]) isa Real
19+
20+
@test kernelmatrix(k, x1, x2) kernelmatrix(k, collect(x1), collect(x2))
21+
@test kernelmatrix(k, x1, x1) kernelmatrix(k, x1)
22+
23+
@test string(k) == "Semi-parametric Latent Factor Multi-Output Kernel"
24+
@test repr("text/plain", k) == (
25+
"Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: " *
26+
"Matern Kernel (ν = 1.5)\n\t\tSquared Exponential Kernel\n" *
27+
"\t\tFractional Brownian Motion Kernel (h = 0.5)\n\teᵢ: " *
28+
"Independent Multi-Output Kernel\n\tSquared Exponential Kernel"
29+
)
30+
31+
# AD test
32+
function test_slfm(A::AbstractMatrix, x1, x2)
33+
k = LatentFactorMOKernel(
34+
[MaternKernel(), SqExponentialKernel(), FBMKernel()],
35+
IndependentMOKernel(GaussianKernel()),
36+
A,
37+
)
38+
return k((x1, 1), (x2, 1))
39+
end
40+
41+
a = rand()
42+
@test all(
43+
FiniteDifferences.j′vp(FDM, test_slfm, a, k.A, x1[1][1], x2[1][1]) .≈
44+
Zygote.pullback(test_slfm, k.A, x1[1][1], x2[1][1])[2](a)
45+
)
46+
47+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ include("test_utils.jl")
117117
@testset "multi_output" begin
118118
include(joinpath("mokernels", "moinput.jl"))
119119
include(joinpath("mokernels", "independent.jl"))
120+
include(joinpath("mokernels", "slfm.jl"))
120121
end
121122
@info "Ran tests on Multi-Output Kernels"
122123

0 commit comments

Comments
 (0)