Skip to content

Add Latent-factor multi-output kernel #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.7.1"
version = "0.7.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
6 changes: 4 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!
export transform
export duplicate, set! # Helpers

export Kernel
export Kernel, MOKernel
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel
export CosineKernel
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
Expand All @@ -32,7 +32,7 @@ export NystromFact, nystrom
export spectral_mixture_kernel, spectral_mixture_product_kernel

export MOInput
export IndependentMOKernel
export IndependentMOKernel, LatentFactorMOKernel

using Compat
using Requires
Expand Down Expand Up @@ -72,8 +72,10 @@ include("kernels/tensorproduct.jl")
include("approximations/nystrom.jl")
include("generic.jl")

include("mokernels/mokernel.jl")
include("mokernels/moinput.jl")
include("mokernels/independent.jl")
include("mokernels/slfm.jl")

include("zygote_adjoints.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ metric(κ::PeriodicKernel) = Sinus(κ.r)

kappa(κ::PeriodicKernel, d::Real) = exp(- 0.5d)

Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = ", length(κ.r), ")")
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel (length(r) = ", length(κ.r), ")")
2 changes: 1 addition & 1 deletion src/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

A multi-output kernel which assumes that each output is independent of the others.
"""
struct IndependentMOKernel{Tkernel<:Kernel} <: Kernel
struct IndependentMOKernel{Tkernel<:Kernel} <: MOKernel
kernel::Tkernel
end

Expand Down
6 changes: 6 additions & 0 deletions src/mokernels/mokernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
MOKernel

An abstract type for multi-output kernels.
"""
abstract type MOKernel <: Kernel end  # supertype of the multi-output kernels; itself a `Kernel`
62 changes: 62 additions & 0 deletions src/mokernels/slfm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
@doc raw"""
    LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix)

The kernel associated with the Semiparametric Latent Factor Model, introduced by
Seeger, Teh and Jordan (2005).

``k((x, p_x), (y, p_y)) = \sum^{Q}_{q=1} A_{p_x q} g_q(x, y) A_{p_y q} + e((x, p_x), (y, p_y))``

where ``Q`` is `length(g)`.

# Arguments
- `g`: a collection of kernels, one for each latent process
- `e`: a [`MOKernel`](@ref) - multi-output kernel
- `A::AbstractMatrix`: a matrix of weights for the kernels of size `(out_dim, length(g))`

# Throws
- An error if any element of `g` is not a `Kernel`.
- An error if `length(g)` differs from `size(A, 2)`.

# Reference:
- [Seeger, Teh, and Jordan (2005)](https://infoscience.epfl.ch/record/161465/files/slfm-long.pdf)

"""
struct LatentFactorMOKernel{Tg, Te <: MOKernel, TA <: AbstractMatrix} <: MOKernel
    g::Tg
    e::Te
    A::TA
    function LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix)
        # Every latent process must be described by a kernel.
        all(gi isa Kernel for gi in g) || error("`g` should be a collection of kernels")
        # One column of weights per latent kernel.
        length(g) == size(A, 2) ||
            error("Size of `A` not compatible with the given array of kernels `g`")
        return new{typeof(g), typeof(e), typeof(A)}(g, e, A)
    end
end

# Evaluate the kernel between input `x` at output index `px` and input `y` at output
# index `py`: the `A`-weighted sum of the latent kernels `g` plus the contribution of `e`.
function (κ::LatentFactorMOKernel)((x, px)::Tuple{Any, Int}, (y, py)::Tuple{Any, Int})
    n_latent = length(κ.g)
    latent_part = sum(κ.A[px, q] * κ.g[q](x, y) * κ.A[py, q] for q in 1:n_latent)
    extra_part = κ.e((x, px), (y, py))
    return latent_part + extra_part
end

# Kernel matrix of a `LatentFactorMOKernel` between two multi-output inputs.
# Errors if `x` and `y` disagree on the output dimension, or if that dimension does not
# match the number of rows of the weight matrix `k.A`.
function kernelmatrix(k::LatentFactorMOKernel, x::MOInput, y::MOInput)
    if x.out_dim != y.out_dim
        error("`x` and `y` should have the same output dimension")
    end
    if x.out_dim != size(k.A, 1)
        error("Kernel not compatible with the given multi-output inputs")
    end

    # Row layout of `y.x`, so that broadcasting against `x.x` yields a full matrix.
    y_row = permutedims(y.x)

    # For each latent process: Kronecker product of the outer product of its weight
    # column with the latent kernel evaluated on all pairs of inputs; then sum over
    # the latent processes.
    latent_sum = sum(
        kron(a * a', gq.(x.x, y_row)) for (a, gq) in zip(eachcol(k.A), k.g)
    )

    return latent_sum .+ kernelmatrix(k.e, x, y)
end

# Compact one-line description of the kernel.
Base.show(io::IO, k::LatentFactorMOKernel) =
    print(io, "Semi-parametric Latent Factor Multi-Output Kernel")

# Multi-line description: the header, then the latent kernels `g` (one per line),
# then the multi-output kernel `e`.
function Base.show(io::IO, ::MIME"text/plain", k::LatentFactorMOKernel)
    print(io, "Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: ")
    join(io, k.g, "\n\t\t")
    print(io, "\n\teᵢ: ")
    # `e` is a single kernel, not a collection, so print it directly instead of joining.
    print(io, k.e)
    return nothing
end
2 changes: 1 addition & 1 deletion test/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@test k(v1, v2) ≈ exp(-0.5 * sum(abs2, sinpi.(v1 - v2) ./ r))
@test k(v1, v2) == k(v2, v1)
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
@test repr(k) == "Periodic Kernel, length(r) = $(length(r)))"
@test repr(k) == "Periodic Kernel (length(r) = $(length(r)))"
# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff"
test_params(k, (r,))
Expand Down
3 changes: 2 additions & 1 deletion test/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

k = IndependentMOKernel(GaussianKernel())
@test k isa IndependentMOKernel
@test k isa MOKernel
@test k isa Kernel
@test k.kernel isa KernelFunctions.Kernel
@test k.kernel isa Kernel
@test k(x[2], y[2]) isa Real

@test kernelmatrix(k, x, y) == kernelmatrix(k, collect(x), collect(y))
Expand Down
47 changes: 47 additions & 0 deletions test/mokernels/slfm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
@testset "slfm" begin
    # Seeded RNG so that every random draw in this testset is reproducible.
    rng = MersenneTwister(123)
    FDM = FiniteDifferences.central_fdm(5, 1)
    N = 10
    in_dim = 5
    out_dim = 4
    x1 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
    x2 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)

    k = LatentFactorMOKernel(
        [MaternKernel(), SqExponentialKernel(), FBMKernel()],
        IndependentMOKernel(GaussianKernel()),
        rand(rng, out_dim, 3),
    )
    @test k isa LatentFactorMOKernel
    @test k isa MOKernel
    @test k isa Kernel
    @test k(x1[1], x2[1]) isa Real

    # The specialised `MOInput` method must agree with the method for collected inputs.
    @test kernelmatrix(k, x1, x2) ≈ kernelmatrix(k, collect(x1), collect(x2))
    @test kernelmatrix(k, x1, x1) ≈ kernelmatrix(k, x1)

    @test string(k) == "Semi-parametric Latent Factor Multi-Output Kernel"
    @test repr("text/plain", k) == (
        "Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: " *
        "Matern Kernel (ν = 1.5)\n\t\tSquared Exponential Kernel\n" *
        "\t\tFractional Brownian Motion Kernel (h = 0.5)\n\teᵢ: " *
        "Independent Multi-Output Kernel\n\tSquared Exponential Kernel"
    )

    # AD test
    function test_slfm(A::AbstractMatrix, x1, x2)
        # Use a distinct local name: assigning to `k` here would rebind the testset's
        # `k`, because nested functions share the locals of their enclosing scope.
        kernel = LatentFactorMOKernel(
            [MaternKernel(), SqExponentialKernel(), FBMKernel()],
            IndependentMOKernel(GaussianKernel()),
            A,
        )
        return kernel((x1, 1), (x2, 1))
    end

    # Draw the seed cotangent from the seeded RNG rather than the global one.
    a = rand(rng)
    @test all(
        FiniteDifferences.j′vp(FDM, test_slfm, a, k.A, x1[1][1], x2[1][1]) .≈
        Zygote.pullback(test_slfm, k.A, x1[1][1], x2[1][1])[2](a)
    )

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ include("test_utils.jl")
@testset "multi_output" begin
include(joinpath("mokernels", "moinput.jl"))
include(joinpath("mokernels", "independent.jl"))
include(joinpath("mokernels", "slfm.jl"))
end
@info "Ran tests on Multi-Output Kernels"

Expand Down