Skip to content

Commit c76b27d

Browse files
authored
Add lazy kronecker product for matrix kernels, if Kronecker.jl is loaded (#364)
* Restore additions * Improvements for lazy kron * Remove unneeded lines * Small experiment with overwriting * Reorder and overwrite * Format and kernelmatrix! * Reinstate separate method * Adding tests * Duplicate code for readability * Format * Remove comment and patch bump * Change kernelmatrix! * Change to output covariance type * Change to output covariance type - revert * Revert "Change to output covariance type - revert" This reverts commit 09cd20e. * Revert "Change kernelmatrix!" This reverts commit f46bd61. * Add kernelmatrix! changes again * Change input types for pairwise pullback * Missing changes to Any
1 parent 671f960 commit c76b27d

File tree

8 files changed

+141
-40
lines changed

8 files changed

+141
-40
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.17"
3+
version = "0.10.18"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ include("kernels/neuralkernelnetwork.jl")
105105
include("approximations/nystrom.jl")
106106
include("generic.jl")
107107

108-
include("mokernels/mokernel.jl")
109108
include("mokernels/moinput.jl")
109+
include("mokernels/mokernel.jl")
110110
include("mokernels/independent.jl")
111111
include("mokernels/slfm.jl")
112112
include("mokernels/intrinsiccoregion.jl")

src/chainrules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function ChainRulesCore.rrule(
2626
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
2727
)
2828
P = Distances.pairwise(d, X, Y; dims=dims)
29-
function pairwise_pullback(::AbstractMatrix)
29+
function pairwise_pullback(::Any)
3030
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
3131
end
3232
return P, pairwise_pullback
@@ -36,7 +36,7 @@ function ChainRulesCore.rrule(
3636
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
3737
)
3838
P = Distances.pairwise(d, X; dims=dims)
39-
function pairwise_pullback(::AbstractMatrix)
39+
function pairwise_pullback(::Any)
4040
return NoTangent(), NoTangent(), ZeroTangent()
4141
end
4242
return P, pairwise_pullback
@@ -46,7 +46,7 @@ function ChainRulesCore.rrule(
4646
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
4747
)
4848
C = Distances.colwise(d, X, Y)
49-
function colwise_pullback(::AbstractVector)
49+
function colwise_pullback(::Any)
5050
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
5151
end
5252
return C, colwise_pullback
@@ -70,7 +70,7 @@ function ChainRulesCore.rrule(
7070
dims=2,
7171
)
7272
P = Distances.pairwise(d, X, Y; dims=dims)
73-
function pairwise_pullback_cols::AbstractMatrix)
73+
function pairwise_pullback_cols::Any)
7474
if dims == 1
7575
return NoTangent(), NoTangent(), Δ * Y, Δ' * X
7676
else
@@ -84,7 +84,7 @@ function ChainRulesCore.rrule(
8484
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
8585
)
8686
P = Distances.pairwise(d, X; dims=dims)
87-
function pairwise_pullback_cols::AbstractMatrix)
87+
function pairwise_pullback_cols::Any)
8888
if dims == 1
8989
return NoTangent(), NoTangent(), 2 * Δ * X
9090
else
@@ -98,7 +98,7 @@ function ChainRulesCore.rrule(
9898
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
9999
)
100100
C = Distances.colwise(d, X, Y)
101-
function colwise_pullback::AbstractVector)
101+
function colwise_pullback::Any)
102102
return NoTangent(), NoTangent(), Δ' .* Y, Δ' .* X
103103
end
104104
return C, colwise_pullback

src/matrix/kernelkroneckermat.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using .Kronecker: Kronecker
55

66
export kernelkronmat
7+
export kronecker_kernelmatrix
78

89
function kernelkronmat::Kernel, X::AbstractVector, dims::Int)
910
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))"
@@ -25,3 +26,42 @@ end
2526
k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ)
2627
"""
2728
@inline iskroncompatible::Kernel) = false # Default return for kernels
29+
30+
function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
31+
return Kronecker.kronecker(Kfeatures, Koutputs)
32+
end
33+
34+
function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
35+
return Kronecker.kronecker(Koutputs, Kfeatures)
36+
end
37+
38+
function kronecker_kernelmatrix(
39+
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
40+
x::IsotopicMOInputsUnion,
41+
y::IsotopicMOInputsUnion,
42+
)
43+
@assert x.out_dim == y.out_dim
44+
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
45+
Koutputs = _mo_output_covariance(k, x.out_dim)
46+
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
47+
end
48+
49+
function kronecker_kernelmatrix(
50+
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
51+
)
52+
Kfeatures = kernelmatrix(k.kernel, x.x)
53+
Koutputs = _mo_output_covariance(k, x.out_dim)
54+
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
55+
end
56+
57+
function kronecker_kernelmatrix(
58+
k::MOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
59+
)
60+
return throw(
61+
ArgumentError("This kernel does not support a lazy kronecker kernelmatrix.")
62+
)
63+
end
64+
65+
function kronecker_kernelmatrix(k::MOKernel, x::IsotopicMOInputsUnion)
66+
return kronecker_kernelmatrix(k, x, x)
67+
end

src/mokernels/independent.jl

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,28 @@ function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,I
2727
return κ.kernel(x, y) * (px == py)
2828
end
2929

30-
function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, B)
31-
return kron(Kfeatures, B)
32-
end
33-
34-
function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, B)
35-
return kron(B, Kfeatures)
36-
end
30+
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)
3731

3832
function kernelmatrix(
39-
k::IndependentMOKernel, x::MOI, y::MOI
40-
) where {MOI<:IsotopicMOInputsUnion}
33+
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
34+
)
4135
@assert x.out_dim == y.out_dim
4236
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
43-
mtype = eltype(Kfeatures)
44-
return _kernelmatrix_kron_helper(x, Kfeatures, Eye{mtype}(x.out_dim))
37+
Koutputs = _mo_output_covariance(k, x.out_dim)
38+
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
4539
end
4640

4741
if VERSION >= v"1.6"
48-
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, B)
49-
return kron!(K, Kfeatures, B)
50-
end
51-
52-
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, B)
53-
return kron!(K, B, Kfeatures)
54-
end
55-
5642
function kernelmatrix!(
57-
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
58-
) where {MOI<:IsotopicMOInputsUnion}
43+
K::AbstractMatrix,
44+
k::IndependentMOKernel,
45+
x::IsotopicMOInputsUnion,
46+
y::IsotopicMOInputsUnion,
47+
)
5948
@assert x.out_dim == y.out_dim
60-
Ktmp = kernelmatrix(k.kernel, x.x, y.x)
61-
mtype = eltype(Ktmp)
62-
return _kernelmatrix_kron_helper!(
63-
K, x, Ktmp, Matrix{mtype}(I, x.out_dim, x.out_dim)
64-
)
49+
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
50+
Koutputs = _mo_output_covariance(k, x.out_dim)
51+
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
6552
end
6653
end
6754

src/mokernels/intrinsiccoregion.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,31 @@ function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{
4242
return k.B[px, py] * k.kernel(x, y)
4343
end
4444

45+
function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
46+
@assert size(k.B) == (out_dim, out_dim)
47+
return k.B
48+
end
49+
4550
function kernelmatrix(
46-
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
47-
) where {MOI<:IsotopicMOInputsUnion}
51+
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
52+
)
4853
@assert x.out_dim == y.out_dim
4954
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
50-
return _kernelmatrix_kron_helper(x, Kfeatures, k.B)
55+
Koutputs = _mo_output_covariance(k, x.out_dim)
56+
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
5157
end
5258

5359
if VERSION >= v"1.6"
5460
function kernelmatrix!(
55-
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
56-
) where {MOI<:IsotopicMOInputsUnion}
61+
K::AbstractMatrix,
62+
k::IntrinsicCoregionMOKernel,
63+
x::IsotopicMOInputsUnion,
64+
y::IsotopicMOInputsUnion,
65+
)
5766
@assert x.out_dim == y.out_dim
5867
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
59-
return _kernelmatrix_kron_helper!(K, x, Kfeatures, k.B)
68+
Koutputs = _mo_output_covariance(k, x.out_dim)
69+
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
6070
end
6171
end
6272

src/mokernels/mokernel.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,21 @@
44
Abstract type for kernels with multiple outpus.
55
"""
66
abstract type MOKernel <: Kernel end
7+
8+
function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
9+
return kron(Kfeatures, Koutputs)
10+
end
11+
12+
function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
13+
return kron(Koutputs, Kfeatures)
14+
end
15+
16+
if VERSION >= v"1.6"
17+
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
18+
return kron!(K, Kfeatures, Koutputs)
19+
end
20+
21+
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
22+
return kron!(K, Koutputs, Kfeatures)
23+
end
24+
end

test/matrix/kernelkroneckermat.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,50 @@
77
@test all(collect(kernelkronmat(k, collect(x), 2)) .≈ kernelmatrix(k, X; obsdim=1))
88
@test all(collect(kernelkronmat(k, [x, x])) .≈ kernelmatrix(k, X; obsdim=1))
99
@test_throws AssertionError kernelkronmat(LinearKernel(), collect(x), 2)
10+
11+
@testset "lazy kernelmatrix" begin
12+
rng = MersenneTwister(123)
13+
14+
dims = (in=3, out=2, obs=3)
15+
r = 1
16+
17+
A = randn(dims.out, r)
18+
B = A * transpose(A) + Diagonal(rand(dims.out))
19+
20+
# XIF = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)]
21+
x = [rand(dims.in) for _ in 1:2]
22+
XIF = KernelFunctions.MOInputIsotopicByFeatures(x, dims.out)
23+
XIO = KernelFunctions.MOInputIsotopicByOutputs(x, dims.out)
24+
y = [rand(dims.in) for _ in 1:2]
25+
YIF = KernelFunctions.MOInputIsotopicByFeatures(y, dims.out)
26+
YIO = KernelFunctions.MOInputIsotopicByOutputs(y, dims.out)
27+
28+
skernel = GaussianKernel()
29+
kIndMO = IndependentMOKernel(skernel)
30+
31+
A = randn(dims.out, r)
32+
B = A * transpose(A) + Diagonal(rand(dims.out))
33+
icoregionkernel = IntrinsicCoregionMOKernel(skernel, B)
34+
35+
function test_kronecker_kernelmatrix(k, x)
36+
res = kronecker_kernelmatrix(k, x)
37+
@test typeof(res) <: Kronecker.KroneckerProduct
38+
@test res == kernelmatrix(k, x)
39+
end
40+
function test_kronecker_kernelmatrix(k, x, y)
41+
res = kronecker_kernelmatrix(k, x, y)
42+
@test typeof(res) <: Kronecker.KroneckerProduct
43+
@test res == kernelmatrix(k, x, y)
44+
end
45+
46+
for k in [kIndMO, icoregionkernel], x in [XIF, XIO]
47+
test_kronecker_kernelmatrix(k, x)
48+
end
49+
for k in [kIndMO, icoregionkernel], (x, y) in ([XIF, YIF], [XIO, YIO])
50+
test_kronecker_kernelmatrix(k, x, y)
51+
end
52+
53+
struct TestMOKernel <: MOKernel end
54+
@test_throws ArgumentError kronecker_kernelmatrix(TestMOKernel(), XIF)
55+
end
1056
end

0 commit comments

Comments
 (0)