Skip to content

Stable version of KernelProduct and added test_type_stability #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 1, 2022
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.48"
version = "0.10.49"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
154 changes: 100 additions & 54 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,94 @@ function test_interface(
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1)
end

function test_interface(
rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
) where {T<:Real}
return test_interface(
k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...
"""
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}=Float64; kwargs...) where {T}

Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
randomly generated inputs.
"""
function test_interface(k::Kernel, T::Type=Float64; kwargs...)
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
end

function test_interface(rng::AbstractRNG, k::Kernel, T::Type=Float64; kwargs...)
return test_with_type(test_interface, rng, k, T; kwargs...)
end

"""
test_type_stability(
k::Kernel,
x0::AbstractVector,
x1::AbstractVector,
x2::AbstractVector,
)

Run type stability checks over `k(x,y)` and the different functions of the API
(`kernelmatrix`, `kernelmatrix_diag`). `x0` and `x1` should be of the same
length with different values, while `x0` and `x2` should be of different lengths.
"""
function test_type_stability(
k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector
)
# Ensure that we have the required inputs.
@assert length(x0) == length(x1)
@assert length(x0) ≠ length(x2)
@test @inferred(kernelmatrix(k, x0)) isa AbstractMatrix
@test @inferred(kernelmatrix(k, x0, x2)) isa AbstractMatrix
@test @inferred(kernelmatrix_diag(k, x0)) isa AbstractVector
@test @inferred(kernelmatrix_diag(k, x0, x1)) isa AbstractVector
end

function test_interface(
rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs...
function test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) where {T}
return test_type_stability(Random.GLOBAL_RNG, k, T; kwargs...)
end

function test_type_stability(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}
return test_with_type(test_type_stability, rng, k, T; kwargs...)
end

"""
test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}

Run the functions `f`, (for example [`test_interface`](@ref) or
[`test_type_stable`](@ref)) for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of `f` with the
randomly generated inputs.
"""
function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}
@testset "Vector{$T}" begin
test_with_type(f, rng, k, Vector{T}; kwargs...)
end
@testset "ColVecs{$T}" begin
test_with_type(f, rng, k, ColVecs{T}; kwargs...)
end
@testset "RowVecs{$T}" begin
test_with_type(f, rng, k, RowVecs{T}; kwargs...)
end
@testset "Vector{Vector{$T}}" begin
test_with_type(f, rng, k, Vector{Vector{T}}; kwargs...)
end
end

function test_with_type(
f, rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
) where {T<:Real}
return f(k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...)
end

function test_with_type(
f, rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs...
) where {T<:Real}
return test_interface(
return f(
k,
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
Expand All @@ -106,10 +182,10 @@ function test_interface(
)
end

function test_interface(
rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...
function test_with_type(
f, rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...
) where {T<:Real}
return test_interface(
return f(
k,
ColVecs(randn(rng, T, dim_in, 11)),
ColVecs(randn(rng, T, dim_in, 11)),
Expand All @@ -118,10 +194,10 @@ function test_interface(
)
end

function test_interface(
rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...
function test_with_type(
f, rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...
) where {T<:Real}
return test_interface(
return f(
k,
RowVecs(randn(rng, T, 11, dim_in)),
RowVecs(randn(rng, T, 11, dim_in)),
Expand All @@ -130,10 +206,10 @@ function test_interface(
)
end

function test_interface(
rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
function test_with_type(
f, rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
) where {T<:Real}
return test_interface(
return f(
k,
[randn(rng, T, dim_in) for _ in 1:11],
[randn(rng, T, dim_in) for _ in 1:11],
Expand All @@ -142,8 +218,8 @@ function test_interface(
)
end

function test_interface(rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...)
return test_interface(
function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...)
return f(
k,
[randstring(rng) for _ in 1:3],
[randstring(rng) for _ in 1:3],
Expand All @@ -152,10 +228,10 @@ function test_interface(rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwa
)
end

function test_interface(
rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
function test_with_type(
f, rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
)
return test_interface(
return f(
k,
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
Expand All @@ -164,38 +240,8 @@ function test_interface(
)
end

function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
end

"""
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real}

Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
randomly generated inputs.
"""
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
@testset "Vector{$T}" begin
test_interface(rng, k, Vector{T}; kwargs...)
end
@testset "ColVecs{$T}" begin
test_interface(rng, k, ColVecs{T}; kwargs...)
end
@testset "RowVecs{$T}" begin
test_interface(rng, k, RowVecs{T}; kwargs...)
end
@testset "Vector{Vector{T}}" begin
test_interface(rng, k, Vector{Vector{T}}; kwargs...)
end
end

function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
function test_with_type(f, k::Kernel, T::Type{<:Real}; kwargs...)
return test_with_type(f, Random.GLOBAL_RNG, k, T; kwargs...)
end

"""
Expand Down
15 changes: 11 additions & 4 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,27 @@ Base.length(k::KernelProduct) = length(k.kernels)

(κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)

function _hadamard(f, ks::Tuple, args...)
return f(first(ks), args...) .* _hadamard(f, Base.tail(ks), args...)
end
_hadamard(f, ks::Tuple{Tx}, args...) where {Tx} = f(only(ks), args...)

(κ::KernelProduct)(x, y) = _hadamard((k, x, y) -> k(x, y), κ.kernels, x, y)

function kernelmatrix(κ::KernelProduct, x::AbstractVector)
return reduce(hadamard, kernelmatrix(k, x) for k in κ.kernels)
return _hadamard(kernelmatrix, κ.kernels, x)
end

function kernelmatrix(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
return reduce(hadamard, kernelmatrix(k, x, y) for k in κ.kernels)
return _hadamard(kernelmatrix, κ.kernels, x, y)
end

function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector)
return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels)
return _hadamard(kernelmatrix_diag, κ.kernels, x)
end

function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels)
return _hadamard(kernelmatrix_diag, κ.kernels, x, y)
end

function Base.show(io::IO, κ::KernelProduct)
Expand Down
14 changes: 7 additions & 7 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,25 @@ end

Base.length(k::KernelSum) = length(k.kernels)

_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])
_sum(f, ks::Tuple, args...) = f(first(ks), args...) + _sum(f, Base.tail(ks), args...)
_sum(f, ks::Tuple{Tx}, args...) where {Tx} = f(only(ks), args...)

(κ::KernelSum)(x, y) = _sum(k -> k(x, y), κ.kernels)
(κ::KernelSum)(x, y) = _sum((k, x, y) -> k(x, y), κ.kernels, x, y)

function kernelmatrix(κ::KernelSum, x::AbstractVector)
return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
return _sum(kernelmatrix, κ.kernels, x)
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
return _sum(kernelmatrix, κ.kernels, x, y)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
return _sum(kernelmatrix_diag, κ.kernels, x)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
return _sum(kernelmatrix_diag, κ.kernels, x, y)
end

function Base.show(io::IO, κ::KernelSum)
Expand Down
7 changes: 5 additions & 2 deletions test/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
)

# Standardised tests.
TestUtils.test_interface(k, Float64)
TestUtils.test_interface(ConstantKernel(; c=1.0) * WhiteKernel(), Vector{String})
test_interface(k, Float64)
test_interface(ConstantKernel(; c=1.0) * WhiteKernel(), Vector{String})
test_ADs(
x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1)
)
test_interface_ad_perf(2.4, StableRNG(123456)) do c
KernelProduct(SqExponentialKernel(), LinearKernel(; c=c))
end
test_params(k1 * k2, (k1, k2))

nested_k = RBFKernel() * (LinearKernel() + CosineKernel() * RBFKernel())
test_type_stability(nested_k)
end
29 changes: 9 additions & 20 deletions test/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
k = KernelSum(k1, k2)
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
@test length(k) == 2
@test string(k) == (
@test repr(k) == (
"Sum of 2 kernels:\n" *
"\tLinear Kernel (c = 0.0)\n" *
"\tSquared Exponential Kernel (metric = Euclidean(0.0))"
)

# Standardised tests.
TestUtils.test_interface(k, Float64)
TestUtils.test_interface(ConstantKernel(; c=1.5) * WhiteKernel(), Vector{String})
test_interface(k, Float64)
test_interface(ConstantKernel(; c=1.5) + WhiteKernel(), Vector{String})
test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1))
test_interface_ad_perf(2.4, StableRNG(123456)) do c
KernelSum(SqExponentialKernel(), LinearKernel(; c=c))
Expand All @@ -21,22 +21,11 @@
test_params(k1 + k2, (k1, k2))

# Regression tests for https://github.com//issues/458
@testset "Type stability" begin
function check_type_stability(k)
@test (@inferred k(0.1, 0.2)) isa Real
x = rand(10)
y = rand(10)
@test (@inferred kernelmatrix(k, x)) isa Matrix{<:Real}
@test (@inferred kernelmatrix(k, x, y)) isa Matrix{<:Real}
@test (@inferred kernelmatrix_diag(k, x)) isa Vector{<:Real}
@test (@inferred kernelmatrix_diag(k, x, y)) isa Vector{<:Real}
end
@testset for k in (
RBFKernel() + RBFKernel() * LinearKernel(),
RBFKernel() + RBFKernel() * ExponentialKernel(),
RBFKernel() * (LinearKernel() + ExponentialKernel()),
)
check_type_stability(k)
end
@testset for k in (
RBFKernel() + RBFKernel() * LinearKernel(),
RBFKernel() + RBFKernel() * ExponentialKernel(),
RBFKernel() * (LinearKernel() + ExponentialKernel()),
)
test_type_stability(k)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using Compat: only

using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils

using KernelFunctions.TestUtils: test_interface, example_inputs
using KernelFunctions.TestUtils: test_interface, test_type_stability, example_inputs

# The GROUP is used to run different sets of tests in parallel on the GitHub Actions CI.
# If you want to introduce a new group, ensure you also add it to .github/workflows/ci.yml
Expand Down