Skip to content

Commit c48301e

Browse files
authored
Add datatype for multi-output GP input and add independent multi-output kernel (#138)
* Add datatype for multi-output GP input * Add IndependentKernel * Address code review * Address code review * Address code review * Update Tests * Redefine Indendent kernel and add kernelmatrix function * Use block diagonal outputs and fix kernel and kix style issues * Remove exports for base functions. * Remove MOKernel type and make MOInput a subtype of AbstractVector * Make Delta metric return promoted type of the input arrays * Promote and convert in single step * Add more tests
1 parent fca569c commit c48301e

File tree

9 files changed

+115
-2
lines changed

9 files changed

+115
-2
lines changed

src/KernelFunctions.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ export NystromFact, nystrom
3030

3131
export spectral_mixture_kernel, spectral_mixture_product_kernel
3232

33+
export MOInput
34+
export IndependentMOKernel
3335

3436
using Compat
3537
using Requires
@@ -69,6 +71,9 @@ include("kernels/tensorproduct.jl")
6971
include("approximations/nystrom.jl")
7072
include("generic.jl")
7173

74+
include("mokernels/moinput.jl")
75+
include("mokernels/independent.jl")
76+
7277
include("zygote_adjoints.jl")
7378

7479
function __init__()

src/distances/delta.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
struct Delta <: Distances.PreMetric
22
end
33

4-
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
4+
@inline function Distances._evaluate(::Delta, a::AbstractVector{Ta}, b::AbstractVector{Tb}) where {Ta, Tb}
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
8-
return a == b
8+
return convert(promote_type(Ta, Tb), a == b)
99
end
1010

1111
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)

src/mokernels/independent.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
IndependentMOKernel(k::Kernel) <: Kernel
3+
4+
A Multi-Output kernel which assumes each output is independent of the other.
5+
"""
6+
struct IndependentMOKernel{Tkernel<:Kernel} <: Kernel
7+
kernel::Tkernel
8+
end
9+
10+
function::IndependentMOKernel)(x::Tuple{Vector, Int}, y::Tuple{Vector, Int})
11+
if last(x) == last(y)
12+
return κ.kernel(first(x), first(y))
13+
else
14+
return 0.0
15+
end
16+
end
17+
18+
function kernelmatrix(k::IndependentMOKernel, x::MOInput, y::MOInput)
19+
@assert x.out_dim == y.out_dim
20+
temp = k.kernel.(x.x, permutedims(y.x))
21+
return cat((temp for _ in 1:y.out_dim)...; dims=(1,2))
22+
end
23+
24+
function Base.show(io::IO, k::IndependentMOKernel)
25+
print(io, string("Independent Multi-Output Kernel\n\t", string(k.kernel)))
26+
end

src/mokernels/moinput.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""
2+
MOInput(x::AbstractVector, out_dim::Integer)
3+
4+
A data type to accomodate modelling multi-dimensional output data.
5+
"""
6+
struct MOInput{T<:AbstractVector} <: AbstractVector{Tuple{Any,Int}}
7+
x::T
8+
out_dim::Integer
9+
end
10+
11+
Base.length(inp::MOInput) = inp.out_dim * length(inp.x)
12+
13+
Base.size(inp::MOInput, d) = d::Integer == 1 ? inp.out_dim * size(inp.x, 1) : 1
14+
Base.size(inp::MOInput) = (inp.out_dim * size(inp.x, 1),)
15+
16+
Base.lastindex(inp::MOInput) = length(inp)
17+
Base.firstindex(inp::MOInput) = 1
18+
19+
function Base.getindex(inp::MOInput, ind::Integer)
20+
if ind > 0
21+
out_dim = ind ÷ length(inp.x) + 1
22+
ind = ind % length(inp.x)
23+
if ind==0 ind = length(inp.x); out_dim-=1 end
24+
return (inp.x[ind], out_dim::Int)
25+
else
26+
throw(BoundsError(string("Trying to access at ", ind)))
27+
end
28+
end
29+
30+
Base.iterate(inp::MOInput) = (inp[1], 1)
31+
Base.iterate(inp::MOInput, state) = (state<length(inp)) ? (inp[state + 1], state + 1) : nothing

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
2121
end
2222

2323
dim(x::AbstractVector{<:Real}) = 1
24+
dim(x::AbstractVector{Tuple{Any,Int}}) = 1
2425

2526
"""
2627
ColVecs(X::AbstractMatrix)

test/matrix/kernelmatrix.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,13 @@ KernelFunctions.kappa(::ToySimpleKernel, d) = exp(-d / 2)
130130
@test kerneldiagmatrix(k, x) kerneldiagmatrix!(tmp_diag, k, X; obsdim=obsdim)
131131
end
132132
end
133+
134+
@testset "Multi Output Kernels" begin
135+
x = MOInput([rand(5) for _ in 1:4], 3)
136+
y = MOInput([rand(5) for _ in 1:4], 3)
137+
138+
k = IndependentMOKernel(GaussianKernel())
139+
@test kernelmatrix(k, x, y) == k.(collect(x), permutedims(collect(y)))
140+
@test kernelmatrix(k, x, x) == kernelmatrix(k, x)
141+
end
133142
end

test/mokernels/independent.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
@testset "independent" begin
2+
x = MOInput([rand(5) for _ in 1:4], 3)
3+
y = MOInput([rand(5) for _ in 1:4], 3)
4+
5+
k = IndependentMOKernel(GaussianKernel())
6+
@test k isa IndependentMOKernel
7+
@test k isa Kernel
8+
@test k.kernel isa KernelFunctions.BaseKernel
9+
@test k(x[2], y[2]) isa Real
10+
11+
@test kernelmatrix(k, x, y) == kernelmatrix(k, collect(x), collect(y))
12+
@test kernelmatrix(k, x, x) == kernelmatrix(k, x)
13+
@test string(k) == "Independent Multi-Output Kernel\n\tSquared Exponential Kernel"
14+
end

test/mokernels/moinput.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
@testset "moinput" begin
2+
3+
x = [rand(5) for _ in 1:4]
4+
mgpi = MOInput(x, 3)
5+
6+
@test length(mgpi) == 12
7+
@test size(mgpi) == (12,)
8+
@test size(mgpi, 1) == 12
9+
@test size(mgpi, 2) == 1
10+
@test lastindex(mgpi) == 12
11+
@test firstindex(mgpi) == 1
12+
@test iterate(mgpi) == (mgpi[1], 1)
13+
@test iterate(mgpi, 2) == (mgpi[3], 3)
14+
@test_throws BoundsError mgpi[0]
15+
16+
@test mgpi[2] == (x[2], 1)
17+
@test mgpi[5] == (x[1], 2)
18+
@test mgpi[7] == (x[3], 2)
19+
@test all([(x_, i) for i in 1:3 for x_ in x ] .== mgpi)
20+
21+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs
105105
end
106106
@info "Ran tests on matrix"
107107

108+
@testset "multi_output" begin
109+
include(joinpath("mokernels", "moinput.jl"))
110+
include(joinpath("mokernels", "independent.jl"))
111+
end
112+
@info "Ran tests on Multi-Output Kernels"
113+
108114
@testset "approximations" begin
109115
include(joinpath("approximations", "nystrom.jl"))
110116
end

0 commit comments

Comments
 (0)