Add Latent-factor multi-output kernel #143


Merged: 19 commits, Sep 2, 2020
Changes from 3 commits
3 changes: 2 additions & 1 deletion src/KernelFunctions.jl
@@ -31,7 +31,7 @@ export NystromFact, nystrom
export spectral_mixture_kernel, spectral_mixture_product_kernel

export MOInput
-export IndependentMOKernel
+export IndependentMOKernel, LatentFactorMOKernel

using Compat
using Requires
@@ -73,6 +73,7 @@ include("generic.jl")

include("mokernels/moinput.jl")
include("mokernels/independent.jl")
include("mokernels/slfm.jl")

include("zygote_adjoints.jl")

2 changes: 1 addition & 1 deletion src/basekernels/periodic.jl
@@ -24,4 +24,4 @@ metric(κ::PeriodicKernel) = Sinus(κ.r)

kappa(κ::PeriodicKernel, d::Real) = exp(- 0.5d)

-Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = ", length(κ.r), ")")
+Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel (length(r) = ", length(κ.r), ")")
63 changes: 63 additions & 0 deletions src/mokernels/slfm.jl
@@ -0,0 +1,63 @@
@doc raw"""
    LatentFactorMOKernel(
        g::AbstractVector{<:Kernel},
        e::AbstractVector{<:Kernel},
        A::AbstractMatrix
    )

A semiparametric kernel for problems involving multiple response variables.

``k((x, p), (y, p)) = k_p(x, y) = \sum^{Q}_{q=1} A_{pq} g_q(x, y) + e_p(x, y)``

# Arguments
- `g::AbstractVector{<:Kernel}`: an array of kernels
- `e::AbstractVector{<:Kernel}`: an array of kernels
- `A::AbstractMatrix`: a matrix of weights for the kernels, of shape (length(e), length(g))

# Reference:
- [Seeger, Teh, and Jordan (2005)](https://infoscience.epfl.ch/record/161465/files/slfm-long.pdf)

"""
struct LatentFactorMOKernel{Tg, Te, TA <: AbstractMatrix} <: Kernel
    g::Tg
    e::Te
    A::TA
    function LatentFactorMOKernel(g, e, A::AbstractMatrix)
        all(isa.(g, Kernel)) || error("`g` should be a collection of kernels")
        all(isa.(e, Kernel)) || error("`e` should be a collection of kernels")
        (length(e), length(g)) == size(A) ||
            error("Size of `A` is not compatible with the given collections of kernels")
        return new{typeof(g), typeof(e), typeof(A)}(g, e, A)
    end
end

function (κ::LatentFactorMOKernel)((x, px)::Tuple{Vector, Int}, (y, py)::Tuple{Vector, Int})
    if px == py
        # weighted sum of the latent kernels plus the output-specific kernel
        return sum([κ.g[i](x, y) * κ.A[px, i] for i in 1:length(κ.g)]) +
            κ.e[px](x, y)
    else
        return 0.0
Member:
This introduces a type instability; maybe we can avoid it by computing some dummy value z of the same type as in the other branch and then returning zero(z)?

Contributor Author:
I am not sure what you mean. How do we ensure the computed dummy value is of the same type without actually executing it?

Member:
I'm not completely sure; I was still wondering what the best way to do it is. Sometimes one can perform a possibly cheaper dummy operation by, e.g., using zeros instead of actually indexing. Alternatively, one could always evaluate the first branch and just call zero(res) on its result res; however, that will perform superfluous computations if last(x) != last(y).

Contributor Author:
Are you worried about the first branch not giving a Float64 output? When could this possibly happen?
Maybe we could explicitly impose Float64 on both outputs:

function (κ::LatentFactorMOKernel)((x, px)::Tuple{Vector, Int}, (y, py)::Tuple{Vector, Int})
    if px == py
        return Float64(sum([κ.g[i](x, y) * κ.A[px, i] for i in 1:length(κ.g)]) +
            κ.e[px](x, y))
    else
        return Float64(0.0)
    end
end

I am not sure if this would cause any problems with AD.

Member:
Yeah, it could happen for dual numbers (or basically any other number type). The type instability also occurs if the kernels return values of type Float32 and the elements of A are of type Float32. IMO this example also shows that it's not a good idea to enforce Float64.
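A minimal sketch of the "always evaluate the first branch, then call zero on the result" option mentioned above; it is type-stable, at the cost of the superfluous work already noted when px != py:

function (κ::LatentFactorMOKernel)((x, px)::Tuple{Vector, Int}, (y, py)::Tuple{Vector, Int})
    # Compute the diagonal-block value unconditionally so that both branches
    # return the same type (Float32, Dual, ...), avoiding the instability.
    val = sum([κ.g[i](x, y) * κ.A[px, i] for i in 1:length(κ.g)]) + κ.e[px](x, y)
    return px == py ? val : zero(val)
end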

Contributor Author:
@willtebbutt @theogf Any suggestions on this?

Member:
I'm not completely clear on why there's this zero branch, as the SLFM produces non-zero covariance between all outputs. Am I missing something?

Contributor Author (@sharanry, Aug 2, 2020):
I am confused. Which kernel(s) from g and e would we use if px!=py?

Contributor Author:
@willtebbutt Is this resolved? I am still unsure how to handle the case where px != py.
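For reference, one reading of the generative model in Seeger, Teh, and Jordan (2005) may clarify this: with ``f_p(x) = \sum^{Q}_{q=1} A_{pq} u_q(x) + v_p(x)``, where ``u_q \sim GP(0, g_q)`` and ``v_p \sim GP(0, e_p)``, the implied covariance between outputs is

``k((x, p), (y, p')) = \sum^{Q}_{q=1} A_{pq} A_{p'q} g_q(x, y) + \delta_{pp'} e_p(x, y)``

which is generally nonzero for p != p'; under that reading the zero branch would be replaced by the cross-covariance term rather than returning 0.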

    end
end

function kernelmatrix(k::LatentFactorMOKernel, x::MOInput, y::MOInput)
    x.out_dim == y.out_dim || error("`x` and `y` should have the same output dimension")
    x.out_dim == size(k.A, 1) ||
        error("Kernel not compatible with the given multi-output inputs")
    # broadcast the kernel over the Cartesian product of all (input, output) pairs
    return k.(x, permutedims(collect(y)))
end

function Base.show(io::IO, k::LatentFactorMOKernel)
    print(io, "Semi-parametric Latent Factor Multi-Output Kernel")
end

function Base.show(io::IO, ::MIME"text/plain", k::LatentFactorMOKernel)
    print(
        io,
        "Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: ",
        [string(gi, "\n\t\t") for gi in k.g]...,
        "\n\teᵢ: ",
        [string(ei, "\n\t\t") for ei in k.e]...,
    )
end
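As a usage sketch of the API added in this file (the specific kernels and the 2×3 weight matrix below are illustrative choices mirroring the new test file, not requirements):

x = MOInput([rand(5) for _ in 1:4], 2)   # 4 inputs, each with 2 outputs
k = LatentFactorMOKernel(
    [MaternKernel(), SqExponentialKernel(), FBMKernel()],  # g: 3 latent kernels
    [ExponentialKernel(), PeriodicKernel(5)],              # e: one kernel per output
    rand(2, 3),                                            # A: (length(e), length(g))
)
k(x[1], x[2])       # evaluate on two (input, output-index) pairs
kernelmatrix(k, x)  # 8×8 matrix over all input/output combinations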
10 changes: 5 additions & 5 deletions test/mokernels/independent.jl
@@ -1,14 +1,14 @@
@testset "independent" begin
x = MOInput([rand(5) for _ in 1:4], 3)
y = MOInput([rand(5) for _ in 1:4], 3)
x1 = MOInput([rand(5) for _ in 1:4], 3)
x2 = MOInput([rand(5) for _ in 1:4], 3)

k = IndependentMOKernel(GaussianKernel())
@test k isa IndependentMOKernel
@test k isa Kernel
@test k.kernel isa KernelFunctions.BaseKernel
@test k(x[2], y[2]) isa Real
@test k(x1[2], x2[2]) isa Real

@test kernelmatrix(k, x, y) == kernelmatrix(k, collect(x), collect(y))
@test kernelmatrix(k, x, x) == kernelmatrix(k, x)
@test kernelmatrix(k, x1, x2) == kernelmatrix(k, collect(x1), collect(x2))
@test kernelmatrix(k, x1, x1) == kernelmatrix(k, x1)
@test string(k) == "Independent Multi-Output Kernel\n\tSquared Exponential Kernel"
end
17 changes: 17 additions & 0 deletions test/mokernels/slfm.jl
@@ -0,0 +1,17 @@
@testset "slfm" begin
x1 = MOInput([rand(5) for _ in 1:4], 2)
x2 = MOInput([rand(5) for _ in 1:4], 2)

k = LatentFactorMOKernel(
[MaternKernel(), SqExponentialKernel(), FBMKernel()],
[ExponentialKernel(), PeriodicKernel(5)],
rand(2, 3)
)
@test k isa LatentFactorMOKernel
@test k isa Kernel
@test k(x1[2], x2[2]) isa Real

@test kernelmatrix(k, x1, x2) == kernelmatrix(k, collect(x1), collect(x2))
@test kernelmatrix(k, x1, x1) == kernelmatrix(k, x1)

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -108,6 +108,7 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs
@testset "multi_output" begin
include(joinpath("mokernels", "moinput.jl"))
include(joinpath("mokernels", "independent.jl"))
include(joinpath("mokernels", "slfm.jl"))
end
@info "Ran tests on Multi-Output Kernels"
