Skip to content

Commit 486ba20

Browse files
authored
Merge branch 'master' into remove-transform
2 parents fe5487e + f63be92 commit 486ba20

File tree

11 files changed

+171
-11
lines changed

11 files changed

+171
-11
lines changed

.github/workflows/TagBot.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: TagBot
2+
on:
3+
schedule:
4+
- cron: 0 * * * *
5+
jobs:
6+
TagBot:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- uses: JuliaRegistries/TagBot@v1
10+
with:
11+
token: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,29 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10-
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
12+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1213
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1314
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1415

1516
[compat]
1617
Compat = "2.2, 3"
1718
Distances = "0.8"
18-
PDMats = "0.9"
19+
Requires = "1.0.1"
1920
SpecialFunctions = "0.8, 0.9, 0.10"
21+
StatsBase = "0.32"
2022
StatsFuns = "0.8, 0.9"
2123
ZygoteRules = "0.2"
2224
julia = "1.0"
2325

2426
[extras]
2527
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
28+
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
29+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2630
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2731
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2832
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2933

3034
[targets]
31-
test = ["Random", "Test", "FiniteDifferences", "Zygote"]
35+
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ kernelmatrix!
6161
kerneldiagmatrix
6262
kerneldiagmatrix!
6363
kernelpdmat
64+
kernelkronmat
6465
transform
6566
```
6667

docs/src/userguide.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ For example:
2222
kernelmatrix(k,A,obsdim=2) # Return a 5x5 matrix
2323
```
2424

25+
We also support specific kernel matrices outputs:
26+
- For a positive-definite matrix object`PDMat` from [`PDMats.jl`](https://github.com/JuliaStats/PDMats.jl). Call `kernelpdmat(k,A,obsdim=1)`, it will create a matrix and in case of bad conditionning will add some diagonal noise until the matrix is considered PSD, it will then return a `PDMat` object. For this method to work in your code you need to include `using PDMats` first
27+
- For a Kronecker matrix, we rely on [`Kronecker.jl`](https://github.com/MichielStock/Kronecker.jl). We give two methods : `kernelkronmat(k,[x,y,z])` where `x` `y` and `z` are vectors which will return a `KroneckerProduct`, and `kernelkronmat(k,x,dims)` where `x` is a vector and dims and the number of features. Make sure that `k` is a vector compatible with such constructions (with `iskroncompatible`). Both method will return a . For those methods to work in your code you need to include `using Kronecker` first
28+
2529
## Kernel manipulation
2630

2731
One can create combinations of kernels via `KernelSum` and `KernelProduct` or using simple operators `+` and `*`.

src/KernelFunctions.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
44
export transform
55
export params, duplicate, set! # Helpers
66

@@ -16,12 +16,14 @@ export TransformedKernel, ScaledKernel
1616

1717
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1818

19+
export NystromFact, nystrom
20+
1921
using Compat
22+
using Requires
2023
using Distances, LinearAlgebra
2124
using SpecialFunctions: logabsgamma, besselk
2225
using ZygoteRules: @adjoint
2326
using StatsFuns: logtwo
24-
using PDMats: PDMat
2527
using InteractiveUtils: subtypes
2628

2729
const defaultobs = 2
@@ -44,12 +46,17 @@ end
4446
include("kernels/transformedkernel.jl")
4547
include("kernels/scaledkernel.jl")
4648
include("matrix/kernelmatrix.jl")
47-
include("matrix/kernelpdmat.jl")
4849
include("kernels/kernelsum.jl")
4950
include("kernels/kernelproduct.jl")
51+
include("approximations/nystrom.jl")
5052

5153
include("generic.jl")
5254

5355
include("zygote_adjoints.jl")
5456

57+
function __init__()
58+
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
59+
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
60+
end
61+
5562
end

src/approximations/nystrom.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Following the algorithm by William and Seeger, 2001
2+
# Cs is equivalent to X_mm and C to X_mn
3+
4+
function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
5+
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
6+
n = size(X, obsdim)
7+
m = ceil(Int, n*r)
8+
S = StatsBase.sample(1:n, m; replace=false, ordered=true)
9+
return S
10+
end
11+
12+
function nystrom_sample(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs)
13+
obsdim [1, 2] || throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
14+
Xₘ = obsdim == 1 ? X[S, :] : X[:, S]
15+
C = k(Xₘ, X; obsdim=obsdim)
16+
Cs = C[:, S]
17+
return (C, Cs)
18+
end
19+
20+
function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T)*size(Cs,1)) where {T<:Real}
21+
# Compute eigendecomposition of sampled component of K
22+
QΛQᵀ = LinearAlgebra.eigen!(LinearAlgebra.Symmetric(Cs))
23+
24+
# Solve for D = Λ^(-1/2) (pseudo inverse - use tolerance from before factorization)
25+
D = QΛQᵀ.values
26+
λ_tol = maximum(D)*tol
27+
28+
for i in eachindex(D)
29+
@inbounds D[i] = abs(D[i]) <= λ_tol ? zero(T) : one(T)/sqrt(D[i])
30+
end
31+
32+
# Scale eigenvectors by D
33+
Q = QΛQᵀ.vectors
34+
QD = LinearAlgebra.rmul!(Q, LinearAlgebra.Diagonal(D)) # Scales column i of Q by D[i]
35+
36+
# W := (QD)(QD)ᵀ = (QΛQᵀ)^(-1) (pseudo inverse)
37+
W = QD*QD'
38+
39+
# Symmetrize W
40+
return LinearAlgebra.copytri!(W, 'U')
41+
end
42+
43+
@doc raw"""
44+
NystromFact
45+
46+
Type for storing a Nystrom factorization. The factorization contains two fields: `W` and
47+
`C`, two matrices satisfying:
48+
```math
49+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
50+
```
51+
"""
52+
struct NystromFact{T<:Real}
53+
W::Matrix{T}
54+
C::Matrix{T}
55+
end
56+
57+
function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
58+
T = Base.promote_eltypeof(W, C)
59+
return NystromFact(convert(Matrix{T}, W), convert(Matrix{T}, C))
60+
end
61+
62+
@doc raw"""
63+
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
64+
65+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
66+
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
67+
Nystrom factorization satisfying:
68+
```math
69+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
70+
```
71+
"""
72+
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
73+
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
74+
W = nystrom_pinv!(Cs)
75+
return NystromFact(W, C)
76+
end
77+
78+
@doc raw"""
79+
nystrom(k::Kernel, X::Matrix, r::Real; obsdim::Int=defaultobs)
80+
81+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
82+
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
83+
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
84+
```math
85+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
86+
```
87+
"""
88+
function nystrom(k::Kernel, X::AbstractMatrix, r::Real; obsdim::Int=defaultobs)
89+
S = sampleindex(X, r; obsdim=obsdim)
90+
return nystrom(k, X, S; obsdim=obsdim)
91+
end
92+
93+
"""
94+
nystrom(CᵀWC::NystromFact)
95+
96+
Compute the approximate kernel matrix based on the Nystrom factorization.
97+
"""
98+
function kernelmatrix(CᵀWC::NystromFact{<:Real})
99+
W = CᵀWC.W
100+
C = CᵀWC.C
101+
return C'*W*C
102+
end

src/matrix/kernelkroeneckermat.jl renamed to src/matrix/kernelkroneckermat.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1+
using .Kronecker
2+
3+
export kernelkronmat
4+
15
function kernelkronmat(
26
κ::Kernel,
37
X::AbstractVector,
48
dims::Int
59
)
6-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices (see `iskroncompatible()`)"
10+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see `iskroncompatible()`)"
711
k = kernelmatrix(κ,reshape(X,:,1),obsdim=1)
8-
K = kron()
12+
kronecker(k,dims)
913
end
1014

1115
function kernelkronmat(
1216
κ::Kernel,
1317
X::AbstractVector{<:AbstractVector};
1418
obsdim::Int=defaultobs
1519
)
16-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices"
20+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices"
1721
Ks = kernelmatrix.(κ,X,obsdim=obsdim)
18-
K = kron(Ks)
22+
K = reduce(,Ks)
1923
end
2024

2125

src/matrix/kernelpdmat.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using PDMats: PDMat
2+
3+
export kernelpdmat
4+
15
"""
26
Compute a positive-definite matrix in the form of a `PDMat` matrix see [PDMats.jl]() with the cholesky decomposition precomputed
37
The algorithm recursively tries to add recursively a diagonal nugget until positive definiteness is achieved or that the noise is too big

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Random
55

66
@testset "KernelFunctions" begin
77
include("test_kernelmatrix.jl")
8+
include("test_approximations.jl")
89
include("test_constructors.jl")
910
# include("test_AD.jl")
1011
include("test_transform.jl")

test/test_approximations.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Distances, LinearAlgebra
2+
using Test
3+
using KernelFunctions
4+
5+
dims = [10,5]
6+
X = rand(dims...)
7+
k = SqExponentialKernel()
8+
@testset "Kernel Matrix Approximations" begin
9+
@testset "Nystrom" begin
10+
for obsdim in [1, 2]
11+
@test kernelmatrix(k, X; obsdim=obsdim) kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
12+
@test kernelmatrix(k, X; obsdim=obsdim) kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
13+
end
14+
end
15+
end

test/test_kernelmatrix.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Distances, LinearAlgebra
22
using Test
33
using KernelFunctions
44
using PDMats
5-
5+
using Kronecker
66
dims = [10,5]
77

88
A = rand(dims...)
@@ -84,4 +84,11 @@ kt = transform(SqExponentialKernel(),s)
8484
# @test_throws ErrorException kernelpdmat(k,ones(100,100),obsdim=obsdim)
8585
end
8686
end
87+
@testset "Kronecker" begin
88+
x = range(0,1,length=10)
89+
X = vcat(collect.(Iterators.product(x,x))'...)
90+
@test all(collect(kernelkronmat(k,collect(x),2)).≈kernelmatrix(k,X,obsdim=1))
91+
@test all(collect(kernelkronmat(k,[x,x])).≈kernelmatrix(k,X,obsdim=1))
92+
@test_throws AssertionError kernelkronmat(LinearKernel(),collect(x),2)
93+
end
8794
end

0 commit comments

Comments
 (0)