Skip to content

Commit dadfbb3

Browse files
Crown421st--willtebbutt
authored
Some improvement and additions to MO kernels (#354)
* Fix type stability of independent kernel * Add convenience functions * Add/fix tests * Suggested change for kernelmatrix computation * Add results to test script * Switching to new kernelmatrix for independent kernel, tests * Change type stability * New kernelmatrix for IntrinsicCoregion kernel, tests * JuliaFormatter * Improve code reuse, Traits * Use FillArray * Remove traits, move to explicit function * Relax matrixkernel types Co-authored-by: st-- <[email protected]> * Adjust MOInput specifications * Change supertype, move matrixkernel * JuliaFormatter * Remove forced typesymmetry, and improve fallback * Improve documentation, add matrixkernel to other kernels * Comment out specialized kernelmatrix! * Add kernelmatrix! back in with version check * Formatter * Improve doc phrasing Co-authored-by: st-- <[email protected]> * Change Union name * Substantially improve tests * Add ed clarity for lazy kronecker * Improve tests for lmm and slfm * Formatter * Update src/mokernels/mokernel.jl Co-authored-by: willtebbutt <[email protected]> * Remove comments Co-authored-by: willtebbutt <[email protected]> * Add check for lmm kernel, tests * Fix constructor mistake * Remove kwarg dispatch * Remove matrixkernel and lazy kron * Remove matrixkernel export Co-authored-by: willtebbutt <[email protected]> * Update src/matrix/kernelkroneckermat.jl Co-authored-by: willtebbutt <[email protected]> * Update src/mokernels/independent.jl Co-authored-by: willtebbutt <[email protected]> * Update src/mokernels/independent.jl Co-authored-by: willtebbutt <[email protected]> * Update src/mokernels/mokernel.jl Co-authored-by: willtebbutt <[email protected]> * helper function name, formatter * Remove explicit in-place tests * Remove temp file and bump patch * Change helper function name * Fixed rename oversight * Fix missing Union * Silly mistake * Make arguments more consistent * Forgot to uncomment Co-authored-by: st-- <[email protected]> Co-authored-by: willtebbutt <[email protected]>
1 parent 7286dd5 commit dadfbb3

File tree

10 files changed

+163
-45
lines changed

10 files changed

+163
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.15"
3+
version = "0.10.16"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/mokernels/independent.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,45 @@ struct IndependentMOKernel{Tkernel<:Kernel} <: MOKernel
2424
end
2525

2626
function::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
27-
if px == py
28-
return κ.kernel(x, y)
29-
else
30-
return 0.0
31-
end
27+
return κ.kernel(x, y) * (px == py)
28+
end
29+
30+
function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, B)
31+
return kron(Kfeatures, B)
3232
end
3333

34-
function kernelmatrix(k::IndependentMOKernel, x::MOInput, y::MOInput)
34+
function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, B)
35+
return kron(B, Kfeatures)
36+
end
37+
38+
function kernelmatrix(
39+
k::IndependentMOKernel, x::MOI, y::MOI
40+
) where {MOI<:IsotopicMOInputsUnion}
3541
@assert x.out_dim == y.out_dim
36-
temp = k.kernel.(x.x, permutedims(y.x))
37-
return cat((temp for _ in 1:(y.out_dim))...; dims=(1, 2))
42+
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
43+
mtype = eltype(Kfeatures)
44+
return _kernelmatrix_kron_helper(x, Kfeatures, Eye{mtype}(x.out_dim))
45+
end
46+
47+
if VERSION >= v"1.6"
48+
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, B)
49+
return kron!(K, Kfeatures, B)
50+
end
51+
52+
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, B)
53+
return kron!(K, B, Kfeatures)
54+
end
55+
56+
function kernelmatrix!(
57+
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
58+
) where {MOI<:IsotopicMOInputsUnion}
59+
@assert x.out_dim == y.out_dim
60+
Ktmp = kernelmatrix(k.kernel, x.x, y.x)
61+
mtype = eltype(Ktmp)
62+
return _kernelmatrix_kron_helper!(
63+
K, x, Ktmp, Matrix{mtype}(I, x.out_dim, x.out_dim)
64+
)
65+
end
3866
end
3967

4068
function Base.show(io::IO, k::IndependentMOKernel)

src/mokernels/intrinsiccoregion.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,32 @@ function IntrinsicCoregionMOKernel(; kernel::Kernel, B::AbstractMatrix)
3434
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
3535
end
3636

37+
function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix)
38+
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
39+
end
40+
3741
function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
3842
return k.B[px, py] * k.kernel(x, y)
3943
end
4044

45+
function kernelmatrix(
46+
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
47+
) where {MOI<:IsotopicMOInputsUnion}
48+
@assert x.out_dim == y.out_dim
49+
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
50+
return _kernelmatrix_kron_helper(x, Kfeatures, k.B)
51+
end
52+
53+
if VERSION >= v"1.6"
54+
function kernelmatrix!(
55+
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
56+
) where {MOI<:IsotopicMOInputsUnion}
57+
@assert x.out_dim == y.out_dim
58+
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
59+
return _kernelmatrix_kron_helper!(K, x, Kfeatures, k.B)
60+
end
61+
end
62+
4163
function Base.show(io::IO, k::IntrinsicCoregionMOKernel)
4264
return print(
4365
io, "Intrinsic Coregion Kernel: ", k.kernel, " with ", size(k.B, 1), " outputs"

src/mokernels/lmm.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
@doc raw"""
2-
LinearMixingModelKernel(g, e::MOKernel, A::AbstractMatrix)
2+
LinearMixingModelKernel(k::Kernel, H::AbstractMatrix)
3+
LinearMixingModelKernel(Tk::AbstractVector{<:Kernel},Th::AbstractMatrix)
34
4-
Kernel associated with the linear mixing model.
5+
Kernel associated with the linear mixing model, taking a vector of `m` kernels and a `m × p` matrix H for a function with `p` outputs. Also accepts a single kernel `k` for use across all `m` basis vectors.
56
67
# Definition
78
@@ -20,6 +21,10 @@ mixing matrix of ``m`` basis vectors spanning the output space.
2021
struct LinearMixingModelKernel{Tk<:AbstractVector{<:Kernel},Th<:AbstractMatrix} <: MOKernel
2122
K::Tk
2223
H::Th
24+
function LinearMixingModelKernel(Tk::AbstractVector{<:Kernel}, H::AbstractMatrix)
25+
@assert length(Tk) == size(H, 1) "Number of kernels and number of rows in H must match"
26+
return new{typeof(Tk),typeof(H)}(Tk, H)
27+
end
2328
end
2429

2530
function LinearMixingModelKernel(k::Kernel, H::AbstractMatrix)

src/mokernels/moinput.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct MOInputIsotopicByOutputs{S,T<:AbstractVector{S}} <: AbstractVector{Tuple{
5858
out_dim::Integer
5959
end
6060

61-
const IsotopicMOInputs = Union{MOInputIsotopicByFeatures,MOInputIsotopicByOutputs}
61+
const IsotopicMOInputsUnion = Union{MOInputIsotopicByFeatures,MOInputIsotopicByOutputs}
6262

6363
function Base.getindex(inp::MOInputIsotopicByOutputs, ind::Integer)
6464
@boundscheck checkbounds(inp, ind)
@@ -74,7 +74,7 @@ function Base.getindex(inp::MOInputIsotopicByFeatures, ind::Integer)
7474
return feature, output_index
7575
end
7676

77-
Base.size(inp::IsotopicMOInputs) = (inp.out_dim * length(inp.x),)
77+
Base.size(inp::IsotopicMOInputsUnion) = (inp.out_dim * length(inp.x),)
7878

7979
function Base.vcat(x::MOInputIsotopicByFeatures, y::MOInputIsotopicByFeatures)
8080
x.out_dim == y.out_dim || throw(DimensionMismatch("out_dim mismatch"))

src/mokernels/slfm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@doc raw"""
2-
LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix)
2+
LatentFactorMOKernel(g::AbstractVector{<:Kernel}, e::MOKernel, A::AbstractMatrix)
33
44
Kernel associated with the semiparametric latent factor model.
55

test/mokernels/independent.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
@testset "independent" begin
2-
x = MOInput([rand(5) for _ in 1:4], 3)
3-
y = MOInput([rand(5) for _ in 1:4], 3)
2+
outdim = 3
3+
x = KernelFunctions.MOInputIsotopicByOutputs([rand(5) for _ in 1:4], outdim)
4+
y = KernelFunctions.MOInputIsotopicByOutputs([rand(5) for _ in 1:4], outdim)
5+
z = KernelFunctions.MOInputIsotopicByOutputs([rand(5) for _ in 1:2], outdim)
6+
7+
xIF = KernelFunctions.MOInputIsotopicByFeatures(x.x, outdim)
8+
yIF = KernelFunctions.MOInputIsotopicByFeatures(y.x, outdim)
9+
zIF = KernelFunctions.MOInputIsotopicByFeatures(z.x, outdim)
410

511
k = IndependentMOKernel(GaussianKernel())
612
@test k isa IndependentMOKernel
713
@test k isa MOKernel
814
@test k isa Kernel
915
@test k.kernel isa Kernel
10-
@test k(x[2], y[2]) isa Real
1116

1217
@test kernelmatrix(k, x, y) == kernelmatrix(k, collect(x), collect(y))
13-
@test kernelmatrix(k, x, x) == kernelmatrix(k, x)
1418

15-
x1 = MOInput(rand(5), 3) # Single dim input
16-
@test k(x1[1], x1[1]) isa Real
17-
@test kernelmatrix(k, x1) isa Matrix
19+
## accuracy
20+
KernelFunctions.TestUtils.test_interface(k, x, y, z)
21+
KernelFunctions.TestUtils.test_interface(k, xIF, yIF, zIF)
22+
23+
# type stability (maybe move to test_interface?)
24+
x2 = MOInput(rand(Float32, 4), 2)
25+
@test k(x2[1], x2[2]) isa Float32
26+
@test k(x2[1], x2[1]) isa Float32
27+
@test eltype(typeof(kernelmatrix(k, x2))) <: Float32
1828

1929
@test string(k) ==
2030
"Independent Multi-Output Kernel\n" *

test/mokernels/intrinsiccoregion.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,39 @@
22
rng = MersenneTwister(123)
33

44
dims = (in=3, out=2, obs=3)
5-
rank = 1
5+
r = 1
66

7-
A = randn(dims.out, rank)
7+
A = randn(dims.out, r)
88
B = A * transpose(A) + Diagonal(rand(dims.out))
99

10-
X = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)]
10+
# XIF = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)]
11+
x = [rand(dims.in) for _ in 1:2]
12+
XIF = KernelFunctions.MOInputIsotopicByFeatures(x, dims.out)
13+
XIO = KernelFunctions.MOInputIsotopicByOutputs(x, dims.out)
14+
y = [rand(dims.in) for _ in 1:2]
15+
YIF = KernelFunctions.MOInputIsotopicByFeatures(y, dims.out)
16+
YIO = KernelFunctions.MOInputIsotopicByOutputs(y, dims.out)
17+
z = [rand(dims.in) for _ in 1:3]
18+
ZIF = KernelFunctions.MOInputIsotopicByFeatures(z, dims.out)
19+
ZIO = KernelFunctions.MOInputIsotopicByOutputs(z, dims.out)
1120

1221
kernel = SqExponentialKernel()
13-
icoregionkernel = IntrinsicCoregionMOKernel(; kernel=kernel, B=B)
22+
icoregionkernel = IntrinsicCoregionMOKernel(kernel, B)
23+
24+
icoregionkernel2 = IntrinsicCoregionMOKernel(; kernel=kernel, B=B)
25+
@test icoregionkernel == icoregionkernel2
1426

1527
@test icoregionkernel.B == B
1628
@test icoregionkernel.kernel == kernel
17-
@test icoregionkernel(X[1], X[1]) B[X[1][2], X[1][2]] * kernel(X[1][1], X[1][1])
18-
@test icoregionkernel(X[1], X[end]) B[X[1][2], X[end][2]] * kernel(X[1][1], X[end][1])
29+
@test icoregionkernel(XIF[1], XIF[1])
30+
B[XIF[1][2], XIF[1][2]] * kernel(XIF[1][1], XIF[1][1])
31+
@test icoregionkernel(XIF[1], XIF[end])
32+
B[XIF[1][2], XIF[end][2]] * kernel(XIF[1][1], XIF[end][1])
33+
34+
# kernelmatrix
35+
KernelFunctions.TestUtils.test_interface(icoregionkernel, XIF, YIF, ZIF)
36+
37+
KernelFunctions.TestUtils.test_interface(icoregionkernel, XIO, YIO, ZIO)
1938

2039
KernelFunctions.TestUtils.test_interface(
2140
icoregionkernel, Vector{Tuple{Float64,Int}}; dim_out=dims.out

test/mokernels/lmm.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
@testset "lmm" begin
22
rng = MersenneTwister(123)
33
FDM = FiniteDifferences.central_fdm(5, 1)
4-
N = 10
4+
N = 6
55
in_dim = 3
6-
out_dim = 6
7-
x1 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
8-
x2 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
9-
H = rand(4, 6)
10-
11-
k = LinearMixingModelKernel(
12-
[Matern32Kernel(), SqExponentialKernel(), FBMKernel(), Matern32Kernel()], H
6+
out_dim = 3
7+
x1IO = KernelFunctions.MOInputIsotopicByOutputs(
8+
[rand(rng, in_dim) for _ in 1:N], out_dim
9+
)
10+
x2IO = KernelFunctions.MOInputIsotopicByOutputs(
11+
[rand(rng, in_dim) for _ in 1:N], out_dim
12+
)
13+
x3IO = KernelFunctions.MOInputIsotopicByOutputs(
14+
[rand(rng, in_dim) for _ in 1:div(N, 2)], out_dim
1315
)
16+
17+
latentkernels = [Matern32Kernel(), SqExponentialKernel(), FBMKernel(), Matern32Kernel()]
18+
H = rand(length(latentkernels), out_dim)
19+
k = LinearMixingModelKernel(latentkernels, H)
20+
21+
badH = rand(length(latentkernels) - 1, out_dim)
22+
@test_throws AssertionError LinearMixingModelKernel(latentkernels, badH)
23+
1424
@test k isa LinearMixingModelKernel
1525
@test k isa MOKernel
1626
@test k isa Kernel
17-
@test k(x1[1], x2[1]) isa Real
27+
@test k(x1IO[1], x2IO[1]) isa Real
1828

1929
@test string(k) == "Linear Mixing Model Multi-Output Kernel"
2030
@test repr("text/plain", k) == (
@@ -25,6 +35,14 @@
2535
"\tMatern 3/2 Kernel (metric = Euclidean(0.0))"
2636
)
2737

38+
TestUtils.test_interface(k, x1IO, x2IO, x3IO)
39+
40+
x1IF = KernelFunctions.MOInputIsotopicByFeatures(x1IO.x, out_dim)
41+
x2IF = KernelFunctions.MOInputIsotopicByFeatures(x2IO.x, out_dim)
42+
x3IF = KernelFunctions.MOInputIsotopicByFeatures(x3IO.x, out_dim)
43+
44+
TestUtils.test_interface(k, x1IF, x2IF, x3IF)
45+
2846
k = LinearMixingModelKernel(SEKernel(), H)
2947

3048
@test k isa LinearMixingModelKernel

test/mokernels/slfm.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
@testset "slfm" begin
22
rng = MersenneTwister(123)
33
FDM = FiniteDifferences.central_fdm(5, 1)
4-
N = 10
5-
in_dim = 5
4+
N = 6
5+
in_dim = 3
66
out_dim = 4
7-
x1 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
8-
x2 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
7+
x1IO = KernelFunctions.MOInputIsotopicByOutputs(
8+
[rand(rng, in_dim) for _ in 1:N], out_dim
9+
)
10+
x2IO = KernelFunctions.MOInputIsotopicByOutputs(
11+
[rand(rng, in_dim) for _ in 1:N], out_dim
12+
)
13+
x3IO = KernelFunctions.MOInputIsotopicByOutputs(
14+
[rand(rng, in_dim) for _ in 1:div(N, 2)], out_dim
15+
)
16+
x1IO = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
17+
x2IO = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
918

1019
k = LatentFactorMOKernel(
1120
[Matern32Kernel(), SqExponentialKernel(), FBMKernel()],
@@ -15,10 +24,17 @@
1524
@test k isa LatentFactorMOKernel
1625
@test k isa MOKernel
1726
@test k isa Kernel
18-
@test k(x1[1], x2[1]) isa Real
27+
@test k(x1IO[1], x2IO[1]) isa Real
28+
29+
@test kernelmatrix(k, x1IO, x2IO) kernelmatrix(k, collect(x1IO), collect(x2IO))
30+
31+
TestUtils.test_interface(k, x1IO, x2IO, x3IO)
32+
33+
x1IF = KernelFunctions.MOInputIsotopicByFeatures(x1IO.x, out_dim)
34+
x2IF = KernelFunctions.MOInputIsotopicByFeatures(x2IO.x, out_dim)
35+
x3IF = KernelFunctions.MOInputIsotopicByFeatures(x3IO.x, out_dim)
1936

20-
@test kernelmatrix(k, x1, x2) kernelmatrix(k, collect(x1), collect(x2))
21-
@test kernelmatrix(k, x1, x1) kernelmatrix(k, x1)
37+
TestUtils.test_interface(k, x1IF, x2IF, x3IF)
2238

2339
@test string(k) == "Semi-parametric Latent Factor Multi-Output Kernel"
2440
@test repr("text/plain", k) == (
@@ -31,13 +47,13 @@
3147
)
3248

3349
# AD test
34-
function test_slfm(A::AbstractMatrix, x1, x2)
50+
function test_slfm(A::AbstractMatrix, x1IO, x2IO)
3551
k = LatentFactorMOKernel(
3652
[Matern32Kernel(), SqExponentialKernel(), FBMKernel()],
3753
IndependentMOKernel(GaussianKernel()),
3854
A,
3955
)
40-
return k((x1, 1), (x2, 1))
56+
return k((x1IO, 1), (x2IO, 1))
4157
end
4258

4359
k = LatentFactorMOKernel(

0 commit comments

Comments
 (0)