Skip to content

Commit 3d3a6c7

Browse files
willtebbuttwt
and
wt
authored
Implement ternary kerneldiagmatrix (#179)
* Implement ternary kerneldiagmatrix * Bump patch version Co-authored-by: wt <[email protected]>
1 parent b67fecc commit 3d3a6c7

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.2"
3+
version = "0.8.3"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/matrix/kernelmatrix.jl

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ kernelmatrix
1919

2020
"""
2121
kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X; obsdim::Int = 2)
22+
kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X, Y; obsdim::Int = 2)
2223
2324
In place version of [`kerneldiagmatrix`](@ref)
2425
"""
@@ -30,6 +31,11 @@ kerneldiagmatrix!
3031
Calculate the diagonal matrix of `X` with respect to kernel `κ`
3132
`obsdim = 1` means the matrix `X` has size #samples x #dimension
3233
`obsdim = 2` means the matrix `X` has size #dimension x #samples
34+
35+
kerneldiagmatrix(κ::Kernel, X, Y; obsdim::Int = 2)
36+
37+
Calculate the diagonal of `kernelmatrix(κ, X, Y; obsdim)` efficiently. Requires that `X` and
38+
`Y` are the same length.
3339
"""
3440
kerneldiagmatrix
3541

@@ -59,8 +65,16 @@ function kerneldiagmatrix!(K::AbstractVector, κ::Kernel, x::AbstractVector)
5965
return map!(x -> κ(x, x), K, x)
6066
end
6167

68+
function kerneldiagmatrix!(
69+
K::AbstractVector, κ::Kernel, x::AbstractVector, y::AbstractVector,
70+
)
71+
return map!(κ, x, y)
72+
end
73+
6274
kerneldiagmatrix::Kernel, x::AbstractVector) = map(x -> κ(x, x), x)
6375

76+
kerneldiagmatrix::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, y)
77+
6478

6579

6680
#
@@ -99,36 +113,47 @@ end
99113
const defaultobs = 2
100114

101115
function kernelmatrix!(
102-
K::AbstractMatrix, κ::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs
116+
K::AbstractMatrix, κ::Kernel, X::AbstractMatrix; obsdim::Int=defaultobs,
103117
)
104118
return kernelmatrix!(K, κ, vec_of_vecs(X; obsdim=obsdim))
105119
end
106120

107121
function kernelmatrix!(
108122
K::AbstractMatrix, κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix;
109-
obsdim::Int = defaultobs
123+
obsdim::Int=defaultobs,
110124
)
111-
x = vec_of_vecs(X; obsdim=obsdim)
112-
y = vec_of_vecs(Y; obsdim=obsdim)
113-
return kernelmatrix!(K, κ, x, y)
125+
return kernelmatrix!(K, κ, vec_of_vecs(X; obsdim=obsdim), vec_of_vecs(Y; obsdim=obsdim))
114126
end
115127

116-
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
128+
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int=defaultobs)
117129
return kernelmatrix(κ, vec_of_vecs(X; obsdim=obsdim))
118130
end
119131

120132
function kernelmatrix::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim=defaultobs)
121-
x = vec_of_vecs(X; obsdim=obsdim)
122-
y = vec_of_vecs(Y; obsdim=obsdim)
123-
return kernelmatrix(κ, x, y)
133+
return kernelmatrix(κ, vec_of_vecs(X; obsdim=obsdim), vec_of_vecs(Y; obsdim=obsdim))
124134
end
125135

126136
function kerneldiagmatrix!(
127-
K::AbstractVector, κ::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs
137+
K::AbstractVector, κ::Kernel, X::AbstractMatrix; obsdim::Int=defaultobs
128138
)
129139
return kerneldiagmatrix!(K, κ, vec_of_vecs(X; obsdim=obsdim))
130140
end
131141

132-
function kerneldiagmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
142+
function kerneldiagmatrix!(
143+
K::AbstractVector, κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix;
144+
obsdim::Int = defaultobs,
145+
)
146+
return kerneldiagmatrix!(
147+
K, κ, vec_of_vecs(X; obsdim=obsdim), vec_of_vecs(Y; obsdim=obsdim),
148+
)
149+
end
150+
151+
function kerneldiagmatrix::Kernel, X::AbstractMatrix; obsdim::Int=defaultobs)
133152
return kerneldiagmatrix(κ, vec_of_vecs(X; obsdim=obsdim))
134153
end
154+
155+
function kerneldiagmatrix(
156+
κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int=defaultobs,
157+
)
158+
return kerneldiagmatrix(κ, vec_of_vecs(X; obsdim=obsdim), vec_of_vecs(Y; obsdim=obsdim))
159+
end

src/test_utils.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,23 @@ function test_interface(
3434
x2::AbstractVector;
3535
atol=__ATOL,
3636
)
37-
# TODO: uncomment the tests of ternary kerneldiagmatrix.
38-
3937
# Ensure that we have the required inputs.
4038
@assert length(x0) == length(x1)
4139
@assert length(x0) length(x2)
4240

4341
# Check that kerneldiagmatrix basically works.
44-
# @test kerneldiagmatrix(k, x0, x1) isa AbstractVector
45-
# @test length(kerneldiagmatrix(k, x0, x1)) == length(x0)
42+
@test kerneldiagmatrix(k, x0, x1) isa AbstractVector
43+
@test length(kerneldiagmatrix(k, x0, x1)) == length(x0)
4644

4745
# Check that pairwise basically works.
4846
@test kernelmatrix(k, x0, x2) isa AbstractMatrix
4947
@test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2))
5048

5149
# Check that elementwise is consistent with pairwise.
52-
# @test kerneldiagmatrix(k, x0, x1) ≈ diag(kernelmatrix(k, x0, x1)) atol=atol
50+
@test kerneldiagmatrix(k, x0, x1) diag(kernelmatrix(k, x0, x1)) atol=atol
5351

5452
# Check additional binary elementwise properties for kernels.
55-
# @test kerneldiagmatrix(k, x0, x1) ≈ kerneldiagmatrix(k, x1, x0)
53+
@test kerneldiagmatrix(k, x0, x1) kerneldiagmatrix(k, x1, x0)
5654
@test kernelmatrix(k, x0, x2) kernelmatrix(k, x2, x0)' atol=atol
5755

5856
# Check that unary elementwise basically works.
@@ -71,7 +69,7 @@ function test_interface(
7169
@test eigmin(Matrix(kernelmatrix(k, x0))) > -atol
7270

7371
# Check that unary elementwise / pairwise are consistent with the binary versions.
74-
# @test kerneldiagmatrix(k, x0) ≈ kerneldiagmatrix(k, x0, x0) atol=atol
72+
@test kerneldiagmatrix(k, x0) kerneldiagmatrix(k, x0, x0) atol=atol
7573
@test kernelmatrix(k, x0) kernelmatrix(k, x0, x0) atol=atol
7674

7775
# Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`.

0 commit comments

Comments
 (0)