Skip to content

Commit 8b743dd

Browse files
Stable version of KernelProduct and added test_type_stability (#486)
* Stable version of product and test * Cleanup * Update test/kernels/kernelproduct.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump version * Generalize approach to all functions * Refactor TestUtils * Uncomment tests * Update src/TestUtils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/kernels/kernelproduct.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/TestUtils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/TestUtils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix formatting and missing definition in kernel sum * Additional formatting * Moar formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent eac3538 commit 8b743dd

File tree

7 files changed

+134
-89
lines changed

7 files changed

+134
-89
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.48"
3+
version = "0.10.49"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/TestUtils.jl

Lines changed: 100 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,94 @@ function test_interface(
8686
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) kernelmatrix_diag(k, x0, x1)
8787
end
8888

89-
function test_interface(
90-
rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
91-
) where {T<:Real}
92-
return test_interface(
93-
k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...
89+
"""
90+
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}=Float64; kwargs...) where {T}
91+
92+
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
93+
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
94+
95+
For other input types, please provide the data manually.
96+
97+
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
98+
randomly generated inputs.
99+
"""
100+
function test_interface(k::Kernel, T::Type=Float64; kwargs...)
101+
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
102+
end
103+
104+
function test_interface(rng::AbstractRNG, k::Kernel, T::Type=Float64; kwargs...)
105+
return test_with_type(test_interface, rng, k, T; kwargs...)
106+
end
107+
108+
"""
109+
test_type_stability(
110+
k::Kernel,
111+
x0::AbstractVector,
112+
x1::AbstractVector,
113+
x2::AbstractVector,
94114
)
115+
116+
Run type stability checks over `k(x,y)` and the different functions of the API
117+
(`kernelmatrix`, `kernelmatrix_diag`). `x0` and `x1` should be of the same
118+
length with different values, while `x0` and `x2` should be of different lengths.
119+
"""
120+
function test_type_stability(
121+
k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector
122+
)
123+
# Ensure that we have the required inputs.
124+
@assert length(x0) == length(x1)
125+
@assert length(x0) length(x2)
126+
@test @inferred(kernelmatrix(k, x0)) isa AbstractMatrix
127+
@test @inferred(kernelmatrix(k, x0, x2)) isa AbstractMatrix
128+
@test @inferred(kernelmatrix_diag(k, x0)) isa AbstractVector
129+
@test @inferred(kernelmatrix_diag(k, x0, x1)) isa AbstractVector
95130
end
96131

97-
function test_interface(
98-
rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs...
132+
function test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) where {T}
133+
return test_type_stability(Random.GLOBAL_RNG, k, T; kwargs...)
134+
end
135+
136+
function test_type_stability(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}
137+
return test_with_type(test_type_stability, rng, k, T; kwargs...)
138+
end
139+
140+
"""
141+
test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}
142+
143+
Run the functions `f`, (for example [`test_interface`](@ref) or
144+
[`test_type_stable`](@ref)) for randomly generated inputs of types `Vector{T}`,
145+
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
146+
147+
For other input types, please provide the data manually.
148+
149+
The keyword arguments are forwarded to the invocations of `f` with the
150+
randomly generated inputs.
151+
"""
152+
function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}
153+
@testset "Vector{$T}" begin
154+
test_with_type(f, rng, k, Vector{T}; kwargs...)
155+
end
156+
@testset "ColVecs{$T}" begin
157+
test_with_type(f, rng, k, ColVecs{T}; kwargs...)
158+
end
159+
@testset "RowVecs{$T}" begin
160+
test_with_type(f, rng, k, RowVecs{T}; kwargs...)
161+
end
162+
@testset "Vector{Vector{$T}}" begin
163+
test_with_type(f, rng, k, Vector{Vector{T}}; kwargs...)
164+
end
165+
end
166+
167+
function test_with_type(
168+
f, rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
169+
) where {T<:Real}
170+
return f(k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...)
171+
end
172+
173+
function test_with_type(
174+
f, rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs...
99175
) where {T<:Real}
100-
return test_interface(
176+
return f(
101177
k,
102178
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
103179
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
@@ -106,10 +182,10 @@ function test_interface(
106182
)
107183
end
108184

109-
function test_interface(
110-
rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...
185+
function test_with_type(
186+
f, rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...
111187
) where {T<:Real}
112-
return test_interface(
188+
return f(
113189
k,
114190
ColVecs(randn(rng, T, dim_in, 11)),
115191
ColVecs(randn(rng, T, dim_in, 11)),
@@ -118,10 +194,10 @@ function test_interface(
118194
)
119195
end
120196

121-
function test_interface(
122-
rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...
197+
function test_with_type(
198+
f, rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...
123199
) where {T<:Real}
124-
return test_interface(
200+
return f(
125201
k,
126202
RowVecs(randn(rng, T, 11, dim_in)),
127203
RowVecs(randn(rng, T, 11, dim_in)),
@@ -130,10 +206,10 @@ function test_interface(
130206
)
131207
end
132208

133-
function test_interface(
134-
rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
209+
function test_with_type(
210+
f, rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
135211
) where {T<:Real}
136-
return test_interface(
212+
return f(
137213
k,
138214
[randn(rng, T, dim_in) for _ in 1:11],
139215
[randn(rng, T, dim_in) for _ in 1:11],
@@ -142,8 +218,8 @@ function test_interface(
142218
)
143219
end
144220

145-
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...)
146-
return test_interface(
221+
function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...)
222+
return f(
147223
k,
148224
[randstring(rng) for _ in 1:3],
149225
[randstring(rng) for _ in 1:3],
@@ -152,10 +228,10 @@ function test_interface(rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwa
152228
)
153229
end
154230

155-
function test_interface(
156-
rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
231+
function test_with_type(
232+
f, rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
157233
)
158-
return test_interface(
234+
return f(
159235
k,
160236
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
161237
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
@@ -164,38 +240,8 @@ function test_interface(
164240
)
165241
end
166242

167-
function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
168-
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
169-
end
170-
171-
"""
172-
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
173-
174-
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
175-
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
176-
177-
For other input types, please provide the data manually.
178-
179-
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
180-
randomly generated inputs.
181-
"""
182-
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
183-
@testset "Vector{$T}" begin
184-
test_interface(rng, k, Vector{T}; kwargs...)
185-
end
186-
@testset "ColVecs{$T}" begin
187-
test_interface(rng, k, ColVecs{T}; kwargs...)
188-
end
189-
@testset "RowVecs{$T}" begin
190-
test_interface(rng, k, RowVecs{T}; kwargs...)
191-
end
192-
@testset "Vector{Vector{T}}" begin
193-
test_interface(rng, k, Vector{Vector{T}}; kwargs...)
194-
end
195-
end
196-
197-
function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
198-
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
243+
function test_with_type(f, k::Kernel, T::Type{<:Real}; kwargs...)
244+
return test_with_type(f, Random.GLOBAL_RNG, k, T; kwargs...)
199245
end
200246

201247
"""

src/kernels/kernelproduct.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,27 @@ Base.length(k::KernelProduct) = length(k.kernels)
4545

4646
::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
4747

48+
function _hadamard(f, ks::Tuple, args...)
49+
return f(first(ks), args...) .* _hadamard(f, Base.tail(ks), args...)
50+
end
51+
_hadamard(f, ks::Tuple{Tx}, args...) where {Tx} = f(only(ks), args...)
52+
53+
::KernelProduct)(x, y) = _hadamard((k, x, y) -> k(x, y), κ.kernels, x, y)
54+
4855
function kernelmatrix::KernelProduct, x::AbstractVector)
49-
return reduce(hadamard, kernelmatrix(k, x) for k in κ.kernels)
56+
return _hadamard(kernelmatrix, κ.kernels, x)
5057
end
5158

5259
function kernelmatrix::KernelProduct, x::AbstractVector, y::AbstractVector)
53-
return reduce(hadamard, kernelmatrix(k, x, y) for k in κ.kernels)
60+
return _hadamard(kernelmatrix, κ.kernels, x, y)
5461
end
5562

5663
function kernelmatrix_diag::KernelProduct, x::AbstractVector)
57-
return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels)
64+
return _hadamard(kernelmatrix_diag, κ.kernels, x)
5865
end
5966

6067
function kernelmatrix_diag::KernelProduct, x::AbstractVector, y::AbstractVector)
61-
return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels)
68+
return _hadamard(kernelmatrix_diag, κ.kernels, x, y)
6269
end
6370

6471
function Base.show(io::IO, κ::KernelProduct)

src/kernels/kernelsum.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,25 @@ end
4343

4444
Base.length(k::KernelSum) = length(k.kernels)
4545

46-
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
47-
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])
46+
_sum(f, ks::Tuple, args...) = f(first(ks), args...) + _sum(f, Base.tail(ks), args...)
47+
_sum(f, ks::Tuple{Tx}, args...) where {Tx} = f(only(ks), args...)
4848

49-
::KernelSum)(x, y) = _sum(k -> k(x, y), κ.kernels)
49+
::KernelSum)(x, y) = _sum((k, x, y) -> k(x, y), κ.kernels, x, y)
5050

5151
function kernelmatrix::KernelSum, x::AbstractVector)
52-
return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
52+
return _sum(kernelmatrix, κ.kernels, x)
5353
end
5454

5555
function kernelmatrix::KernelSum, x::AbstractVector, y::AbstractVector)
56-
return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
56+
return _sum(kernelmatrix, κ.kernels, x, y)
5757
end
5858

5959
function kernelmatrix_diag::KernelSum, x::AbstractVector)
60-
return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
60+
return _sum(kernelmatrix_diag, κ.kernels, x)
6161
end
6262

6363
function kernelmatrix_diag::KernelSum, x::AbstractVector, y::AbstractVector)
64-
return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
64+
return _sum(kernelmatrix_diag, κ.kernels, x, y)
6565
end
6666

6767
function Base.show(io::IO, κ::KernelSum)

test/kernels/kernelproduct.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
)
1111

1212
# Standardised tests.
13-
TestUtils.test_interface(k, Float64)
14-
TestUtils.test_interface(ConstantKernel(; c=1.0) * WhiteKernel(), Vector{String})
13+
test_interface(k, Float64)
14+
test_interface(ConstantKernel(; c=1.0) * WhiteKernel(), Vector{String})
1515
test_ADs(
1616
x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1)
1717
)
1818
test_interface_ad_perf(2.4, StableRNG(123456)) do c
1919
KernelProduct(SqExponentialKernel(), LinearKernel(; c=c))
2020
end
2121
test_params(k1 * k2, (k1, k2))
22+
23+
nested_k = RBFKernel() * (LinearKernel() + CosineKernel() * RBFKernel())
24+
test_type_stability(nested_k)
2225
end

test/kernels/kernelsum.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
k = KernelSum(k1, k2)
55
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
66
@test length(k) == 2
7-
@test string(k) == (
7+
@test repr(k) == (
88
"Sum of 2 kernels:\n" *
99
"\tLinear Kernel (c = 0.0)\n" *
1010
"\tSquared Exponential Kernel (metric = Euclidean(0.0))"
1111
)
1212

1313
# Standardised tests.
14-
TestUtils.test_interface(k, Float64)
15-
TestUtils.test_interface(ConstantKernel(; c=1.5) * WhiteKernel(), Vector{String})
14+
test_interface(k, Float64)
15+
test_interface(ConstantKernel(; c=1.5) + WhiteKernel(), Vector{String})
1616
test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1))
1717
test_interface_ad_perf(2.4, StableRNG(123456)) do c
1818
KernelSum(SqExponentialKernel(), LinearKernel(; c=c))
@@ -21,22 +21,11 @@
2121
test_params(k1 + k2, (k1, k2))
2222

2323
# Regression tests for https://github.com//issues/458
24-
@testset "Type stability" begin
25-
function check_type_stability(k)
26-
@test (@inferred k(0.1, 0.2)) isa Real
27-
x = rand(10)
28-
y = rand(10)
29-
@test (@inferred kernelmatrix(k, x)) isa Matrix{<:Real}
30-
@test (@inferred kernelmatrix(k, x, y)) isa Matrix{<:Real}
31-
@test (@inferred kernelmatrix_diag(k, x)) isa Vector{<:Real}
32-
@test (@inferred kernelmatrix_diag(k, x, y)) isa Vector{<:Real}
33-
end
34-
@testset for k in (
35-
RBFKernel() + RBFKernel() * LinearKernel(),
36-
RBFKernel() + RBFKernel() * ExponentialKernel(),
37-
RBFKernel() * (LinearKernel() + ExponentialKernel()),
38-
)
39-
check_type_stability(k)
40-
end
24+
@testset for k in (
25+
RBFKernel() + RBFKernel() * LinearKernel(),
26+
RBFKernel() + RBFKernel() * ExponentialKernel(),
27+
RBFKernel() * (LinearKernel() + ExponentialKernel()),
28+
)
29+
test_type_stability(k)
4130
end
4231
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using Compat: only
2020

2121
using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils
2222

23-
using KernelFunctions.TestUtils: test_interface, example_inputs
23+
using KernelFunctions.TestUtils: test_interface, test_type_stability, example_inputs
2424

2525
# The GROUP is used to run different sets of tests in parallel on the GitHub Actions CI.
2626
# If you want to introduce a new group, ensure you also add it to .github/workflows/ci.yml

0 commit comments

Comments
 (0)