Skip to content

Commit f9bbd84

Browse files
st--github-actions[bot]devmotiontheogf
authored
make nystrom work with AbstractVector (#427)
* make nystrom work with AbstractVector * add test * Update test/approximations/nystrom.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * patch bump * Update test/approximations/nystrom.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review * deprecate * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: Théo Galy-Fajou <[email protected]> * Update src/approximations/nystrom.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Update src/approximations/nystrom.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Théo Galy-Fajou <[email protected]>
1 parent d1c68a9 commit f9bbd84

File tree

3 files changed

+44
-24
lines changed

3 files changed

+44
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.27"
3+
version = "0.10.28"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/approximations/nystrom.jl

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
# Following the algorithm by William and Seeger, 2001
22
# Cs is equivalent to X_mm and C to X_mn
33

4-
function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
4+
function sampleindex(X::AbstractVector, r::Real)
55
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
6-
n = size(X, obsdim)
6+
n = length(X)
77
m = ceil(Int, n * r)
88
S = StatsBase.sample(1:n, m; replace=false, ordered=true)
99
return S
1010
end
1111

12-
function nystrom_sample(
13-
k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs
14-
)
15-
obsdim [1, 2] ||
16-
throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
17-
Xₘ = obsdim == 1 ? X[S, :] : X[:, S]
18-
C = kernelmatrix(k, Xₘ, X; obsdim=obsdim)
12+
@deprecate sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs) sampleindex(
13+
vec_of_vecs(X; obsdim=obsdim), r
14+
) false
15+
16+
function nystrom_sample(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
17+
Xₘ = @view X[S]
18+
C = kernelmatrix(k, Xₘ, X)
1919
Cs = C[:, S]
2020
return (C, Cs)
2121
end
2222

23+
@deprecate nystrom_sample(
24+
k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs
25+
) nystrom_sample(k, vec_of_vecs(X; obsdim=obsdim), S) false
26+
2327
function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T) * size(Cs, 1)) where {T<:Real}
2428
# Compute eigendecomposition of sampled component of K
2529
QΛQᵀ = LinearAlgebra.eigen!(LinearAlgebra.Symmetric(Cs))
@@ -63,38 +67,48 @@ function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
6367
end
6468

6569
@doc raw"""
66-
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
70+
nystrom(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
6771
68-
Computes a factorization of Nystrom approximation of the square kernel matrix of data
69-
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
70-
Nystrom factorization satisfying:
72+
Compute a factorization of a Nystrom approximation of the square kernel matrix
73+
of data vector `X` with respect to kernel `k`, using indices `S`.
74+
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
7175
```math
7276
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
7377
```
7478
"""
75-
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
76-
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
79+
function nystrom(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
80+
C, Cs = nystrom_sample(k, X, S)
7781
W = nystrom_pinv!(Cs)
7882
return NystromFact(W, C)
7983
end
8084

8185
@doc raw"""
82-
nystrom(k::Kernel, X::Matrix, r::Real; obsdim::Int=defaultobs)
86+
nystrom(k::Kernel, X::AbstractVector, r::Real)
8387
84-
Computes a factorization of Nystrom approximation of the square kernel matrix of data
85-
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
88+
Compute a factorization of a Nystrom approximation of the square kernel matrix
89+
of data vector `X` with respect to kernel `k` using a sample ratio of `r`.
8690
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
8791
```math
8892
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
8993
```
9094
"""
95+
function nystrom(k::Kernel, X::AbstractVector, r::Real)
96+
S = sampleindex(X, r)
97+
return nystrom(k, X, S)
98+
end
99+
100+
function nystrom(
101+
k::Kernel, X::AbstractMatrix, S::AbstractVector{<:Integer}; obsdim::Int=defaultobs
102+
)
103+
return nystrom(k, vec_of_vecs(X; obsdim=obsdim), S)
104+
end
105+
91106
function nystrom(k::Kernel, X::AbstractMatrix, r::Real; obsdim::Int=defaultobs)
92-
S = sampleindex(X, r; obsdim=obsdim)
93-
return nystrom(k, X, S; obsdim=obsdim)
107+
return nystrom(k, vec_of_vecs(X; obsdim=obsdim), r)
94108
end
95109

96110
"""
97-
nystrom(CᵀWC::NystromFact)
111+
kernelmatrix(CᵀWC::NystromFact)
98112
99113
Compute the approximate kernel matrix based on the Nystrom factorization.
100114
"""

test/approximations/nystrom.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
dims = [10, 5]
33
X = rand(dims...)
44
k = SqExponentialKernel()
5+
for obsdim in [1, 2]
6+
Xv = vec_of_vecs(X; obsdim=obsdim)
7+
@assert Xv isa Union{ColVecs,RowVecs}
8+
@test kernelmatrix(k, Xv) kernelmatrix(nystrom(k, Xv, 1.0))
9+
@test kernelmatrix(k, Xv) kernelmatrix(nystrom(k, Xv, collect(1:dims[obsdim])))
10+
end
511
for obsdim in [1, 2]
612
@test kernelmatrix(k, X; obsdim=obsdim)
7-
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
13+
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
814
@test kernelmatrix(k, X; obsdim=obsdim)
9-
kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
15+
kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
1016
end
1117
end

0 commit comments

Comments
 (0)