Skip to content

Commit f63be92

Browse files
authored
Merge pull request #36 from theogf/kronecker
Implemented Kronecker matrices with Kronecker.jl
2 parents 25bf25a + 479cdd7 commit f63be92

File tree

7 files changed

+37
-11
lines changed

7 files changed

+37
-11
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ version = "0.2.4"
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
9+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1111
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1212
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515
[compat]
1616
Compat = "2.2, 3.2"
1717
Distances = "0.8"
18-
PDMats = "0.9"
18+
Requires = "1.0.1"
1919
SpecialFunctions = "0.8, 0.9, 0.10"
2020
StatsBase = "0.32"
2121
StatsFuns = "0.8, 0.9"
@@ -24,9 +24,11 @@ julia = "1.0"
2424

2525
[extras]
2626
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
27+
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
28+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2729
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2830
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2931
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3032

3133
[targets]
32-
test = ["Random", "Test", "FiniteDifferences", "Zygote"]
34+
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ kernelmatrix!
5959
kerneldiagmatrix
6060
kerneldiagmatrix!
6161
kernelpdmat
62+
kernelkronmat
6263
transform
6364
```
6465

docs/src/userguide.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ For example:
2222
kernelmatrix(k,A,obsdim=2) # Return a 5x5 matrix
2323
```
2424

25+
We also support specific kernel matrices outputs:
26+
- For a positive-definite matrix object`PDMat` from [`PDMats.jl`](https://github.com/JuliaStats/PDMats.jl). Call `kernelpdmat(k,A,obsdim=1)`, it will create a matrix and in case of bad conditionning will add some diagonal noise until the matrix is considered PSD, it will then return a `PDMat` object. For this method to work in your code you need to include `using PDMats` first
27+
- For a Kronecker matrix, we rely on [`Kronecker.jl`](https://github.com/MichielStock/Kronecker.jl). We give two methods : `kernelkronmat(k,[x,y,z])` where `x` `y` and `z` are vectors which will return a `KroneckerProduct`, and `kernelkronmat(k,x,dims)` where `x` is a vector and dims and the number of features. Make sure that `k` is a vector compatible with such constructions (with `iskroncompatible`). Both method will return a . For those methods to work in your code you need to include `using Kronecker` first
28+
2529
## Kernel manipulation
2630

2731
One can create combinations of kernels via `KernelSum` and `KernelProduct` or using simple operators `+` and `*`.

src/KernelFunctions.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa # Main matrix functions
44
export params, duplicate, set! # Helpers
55

66
export Kernel
@@ -17,12 +17,12 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransf
1717
export NystromFact, nystrom
1818

1919
using Compat
20+
using Requires
2021
using Distances, LinearAlgebra
2122
using SpecialFunctions: logabsgamma, besselk
2223
using ZygoteRules: @adjoint
2324
using StatsFuns: logtwo
2425
using StatsBase
25-
using PDMats: PDMat
2626

2727
const defaultobs = 2
2828

@@ -41,7 +41,6 @@ for k in ["exponential","matern","polynomial","constant","rationalquad","exponen
4141
include(joinpath("kernels",k*".jl"))
4242
end
4343
include("matrix/kernelmatrix.jl")
44-
include("matrix/kernelpdmat.jl")
4544
include("kernels/kernelsum.jl")
4645
include("kernels/kernelproduct.jl")
4746
include("approximations/nystrom.jl")
@@ -50,4 +49,9 @@ include("generic.jl")
5049

5150
include("zygote_adjoints.jl")
5251

52+
function __init__()
53+
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
54+
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
55+
end
56+
5357
end

src/matrix/kernelkroeneckermat.jl renamed to src/matrix/kernelkroneckermat.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1+
using .Kronecker
2+
3+
export kernelkronmat
4+
15
function kernelkronmat(
26
κ::Kernel,
37
X::AbstractVector,
48
dims::Int
59
)
6-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices (see `iskroncompatible()`)"
10+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see `iskroncompatible()`)"
711
k = kernelmatrix(κ,reshape(X,:,1),obsdim=1)
8-
K = kron()
12+
kronecker(k,dims)
913
end
1014

1115
function kernelkronmat(
1216
κ::Kernel,
1317
X::AbstractVector{<:AbstractVector};
1418
obsdim::Int=defaultobs
1519
)
16-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices"
20+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices"
1721
Ks = kernelmatrix.(κ,X,obsdim=obsdim)
18-
K = kron(Ks)
22+
K = reduce(,Ks)
1923
end
2024

2125

src/matrix/kernelpdmat.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using PDMats: PDMat
2+
3+
export kernelpdmat
4+
15
"""
26
Compute a positive-definite matrix in the form of a `PDMat` matrix see [PDMats.jl]() with the cholesky decomposition precomputed
37
The algorithm recursively tries to add recursively a diagonal nugget until positive definiteness is achieved or that the noise is too big

test/test_kernelmatrix.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Distances, LinearAlgebra
22
using Test
33
using KernelFunctions
44
using PDMats
5-
5+
using Kronecker
66
dims = [10,5]
77

88
A = rand(dims...)
@@ -67,4 +67,11 @@ k = SqExponentialKernel()
6767
# @test_throws ErrorException kernelpdmat(k,ones(100,100),obsdim=obsdim)
6868
end
6969
end
70+
@testset "Kronecker" begin
71+
x = range(0,1,length=10)
72+
X = vcat(collect.(Iterators.product(x,x))'...)
73+
@test all(collect(kernelkronmat(k,collect(x),2)).≈kernelmatrix(k,X,obsdim=1))
74+
@test all(collect(kernelkronmat(k,[x,x])).≈kernelmatrix(k,X,obsdim=1))
75+
@test_throws AssertionError kernelkronmat(LinearKernel(),collect(x),2)
76+
end
7077
end

0 commit comments

Comments
 (0)