test utils revamp #159
@@ -1,48 +1,62 @@
""" | ||||||
RationalQuadraticKernel(; α = 2.0) | ||||||
RationalQuadraticKernel(; α=2.0) | ||||||
|
||||||
The rational-quadratic kernel is a Mercer kernel given by the formula: | ||||||
``` | ||||||
κ(x,y)=(1+||x−y||²/α)^(-α) | ||||||
κ(x, y) = (1 + ||x − y||² / (2α))^(-α) | ||||||
``` | ||||||
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization. | ||||||
where `α` is a shape parameter of the Euclidean distance. Check | ||||||
[`GammaRationalQuadraticKernel`](@ref) for a generalization. | ||||||
""" | ||||||
struct RationalQuadraticKernel{Tα<:Real} <: SimpleKernel | ||||||
α::Vector{Tα} | ||||||
function RationalQuadraticKernel(;alpha::T=2.0, α::T=alpha) where {T} | ||||||
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1") | ||||||
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 0") | ||||||
return new{T}([α]) | ||||||
end | ||||||
end | ||||||
|
||||||
@functor RationalQuadraticKernel | ||||||
|
||||||
kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/first(κ.α))^(-first(κ.α)) | ||||||
function kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} | ||||||
return (one(T) + d² / (2 * first(κ.α)))^(-first(κ.α)) | ||||||
end | ||||||
|
||||||
metric(::RationalQuadraticKernel) = SqEuclidean() | ||||||
|
||||||
Base.show(io::IO, κ::RationalQuadraticKernel) = print(io, "Rational Quadratic Kernel (α = ", first(κ.α), ")") | ||||||
function Base.show(io::IO, κ::RationalQuadraticKernel) | ||||||
print(io, "Rational Quadratic Kernel (α = $(first(κ.α)))") | ||||||
end | ||||||
|
||||||
""" | ||||||
`GammaRationalQuadraticKernel([ρ=1.0[,α=2.0[,γ=2.0]]])` | ||||||
`GammaRationalQuadraticKernel([α=2.0 [, γ=2.0]])` | ||||||
|
||||||
The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the formula: | ||||||
``` | ||||||
κ(x,y)=(1+ρ^(2γ)||x−y||^(2γ)/α)^(-α) | ||||||
κ(x, y) = (1 + ||x−y||^γ / α)^(-α) | ||||||
``` | ||||||
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter. | ||||||
""" | ||||||
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: SimpleKernel | ||||||
α::Vector{Tα} | ||||||
γ::Vector{Tγ} | ||||||
function GammaRationalQuadraticKernel(;alpha::Tα=2.0, gamma::Tγ=2.0, α::Tα=alpha, γ::Tγ=gamma) where {Tα<:Real, Tγ<:Real} | ||||||
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1") | ||||||
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1") | ||||||
function GammaRationalQuadraticKernel( | ||||||
;alpha::Tα=2.0, gamma::Tγ=2.0, α::Tα=alpha, γ::Tγ=gamma, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess
Suggested change
would be more performant in the default case? The question here is also if we should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although probably, since we divide gamma by 2 anyway in the computation it doesn't matter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm I wonder whether this is what you want. Generally speaking, if you're using this kernel you probably don't want the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. |
||||||
) where {Tα<:Real, Tγ<:Real} | ||||||
@check_args(GammaRationalQuadraticKernel, α, α > zero(Tα), "α > 0") | ||||||
@check_args(GammaRationalQuadraticKernel, γ, zero(γ) < γ <= 2, "0 < γ <= 2") | ||||||
return new{Tα, Tγ}([α], [γ]) | ||||||
end | ||||||
end | ||||||
|
||||||
@functor GammaRationalQuadraticKernel | ||||||
|
||||||
kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²^first(κ.γ)/first(κ.α))^(-first(κ.α)) | ||||||
function kappa(κ::GammaRationalQuadraticKernel, d²::Real) | ||||||
return (one(d²) + d²^(first(κ.γ) / 2) / first(κ.α))^(-first(κ.α)) | ||||||
end | ||||||
|
||||||
metric(::GammaRationalQuadraticKernel) = SqEuclidean() | ||||||
|
||||||
Base.show(io::IO, κ::GammaRationalQuadraticKernel) = print(io, "Gamma Rational Quadratic Kernel (α = ", first(κ.α), ", γ = ", first(κ.γ), ")") | ||||||
function Base.show(io::IO, κ::GammaRationalQuadraticKernel) | ||||||
print(io, "Gamma Rational Quadratic Kernel (α = $(first(κ.α)), γ = $(first(κ.γ)))") | ||||||
end |
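For a quick sanity check of the updated formulas, here is a minimal sketch (not part of the diff) that evaluates both kernels and compares the results against the docstring expressions. It assumes the constructors and the callable `k(x, y)` evaluation behave as shown above; the input values are arbitrary.

```julia
using KernelFunctions

# Sketch: check the documented formulas against kernel evaluation.
# Assumes the constructors and `k(x, y)` behave as in the diff above.
k_rq = RationalQuadraticKernel(; α=2.0)
k_grq = GammaRationalQuadraticKernel(; α=2.0, γ=2.0)

x, y = randn(3), randn(3)
d² = sum(abs2, x - y)  # squared Euclidean distance

# Rational quadratic: κ(x, y) = (1 + ||x − y||² / (2α))^(-α)
@assert k_rq(x, y) ≈ (1 + d² / (2 * 2.0))^(-2.0)

# Gamma rational quadratic: κ(x, y) = (1 + ||x − y||^γ / α)^(-α);
# with γ = 2 this reduces to (1 + d² / α)^(-α).
@assert k_grq(x, y) ≈ (1 + d² / 2.0)^(-2.0)
```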
@@ -0,0 +1,135 @@
module TestUtils

const __ATOL = 1e-9

using LinearAlgebra
using KernelFunctions
using Random
using Test

""" | ||
test_interface( | ||
k::Kernel, | ||
x0::AbstractVector, | ||
x1::AbstractVector, | ||
x2::AbstractVector; | ||
atol=__ATOL, | ||
) | ||
|
||
Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`. | ||
`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should | ||
be of different lengths. | ||
|
||
test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:AbstractVector}; atol=__ATOL) | ||
|
||
`test_interface` offers certain types of test data generation to make running these tests | ||
require less code for common input types. For example, `Vector{<:Real}`, `ColVecs{<:Real}`, | ||
and `RowVecs{<:Real}` are supported. For other input vector types, please provide the data | ||
manually. | ||
""" | ||
function test_interface(
    k::Kernel,
    x0::AbstractVector,
    x1::AbstractVector,
    x2::AbstractVector;
    atol=__ATOL,
)
    # TODO: uncomment the tests of ternary kerneldiagmatrix.

    # Ensure that we have the required inputs.
    @assert length(x0) == length(x1)
    @assert length(x0) ≠ length(x2)

    # Check that kerneldiagmatrix basically works.
    # @test kerneldiagmatrix(k, x0, x1) isa AbstractVector
    # @test length(kerneldiagmatrix(k, x0, x1)) == length(x0)

    # Check that pairwise basically works.
    @test kernelmatrix(k, x0, x2) isa AbstractMatrix
    @test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2))

    # Check that elementwise is consistent with pairwise.
    # @test kerneldiagmatrix(k, x0, x1) ≈ diag(kernelmatrix(k, x0, x1)) atol=atol

    # Check additional binary elementwise properties for kernels.
    # @test kerneldiagmatrix(k, x0, x1) ≈ kerneldiagmatrix(k, x1, x0)
    @test kernelmatrix(k, x0, x2) ≈ kernelmatrix(k, x2, x0)' atol=atol

    # Check that unary elementwise basically works.
    @test kerneldiagmatrix(k, x0) isa AbstractVector
    @test length(kerneldiagmatrix(k, x0)) == length(x0)

    # Check that unary pairwise basically works.
    @test kernelmatrix(k, x0) isa AbstractMatrix
    @test size(kernelmatrix(k, x0)) == (length(x0), length(x0))
    @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0)' atol=atol

    # Check that unary elementwise is consistent with unary pairwise.
    @test kerneldiagmatrix(k, x0) ≈ diag(kernelmatrix(k, x0)) atol=atol

    # Check that unary pairwise produces a positive definite matrix (approximately).
    @test eigmin(Matrix(kernelmatrix(k, x0))) > -atol

    # Check that unary elementwise / pairwise are consistent with the binary versions.
    # @test kerneldiagmatrix(k, x0) ≈ kerneldiagmatrix(k, x0, x0) atol=atol
    @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0, x0) atol=atol

    # Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`.
    @test k(first(x0), first(x1)) isa Real
    @test kernelmatrix(k, x0, x2) ≈ [k(xl, xr) for xl in x0, xr in x2]

    tmp = Matrix{Float64}(undef, length(x0), length(x2))
    @test kernelmatrix!(tmp, k, x0, x2) ≈ kernelmatrix(k, x0, x2)

    tmp_square = Matrix{Float64}(undef, length(x0), length(x0))
    @test kernelmatrix!(tmp_square, k, x0) ≈ kernelmatrix(k, x0)

    tmp_diag = Vector{Float64}(undef, length(x0))
    @test kerneldiagmatrix!(tmp_diag, k, x0) ≈ kerneldiagmatrix(k, x0)
end

Inline review thread on this function:

> This might be too much, but maybe adding some printing for each of the tests would give some nice feedback to the user and avoid stalling in Travis?
>
> Hmm, I'm reluctant to do this in this PR. The user has plenty of feedback when the tests fail in the way that the test sets are printed, and I personally prefer to minimise output during the running of tests.

function test_interface(
    rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
) where {T<:Real}
    test_interface(k, randn(rng, T, 3), randn(rng, T, 3), randn(rng, T, 2); kwargs...)
end

function test_interface(
    rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...,
) where {T<:Real}
    test_interface(
        k,
        ColVecs(randn(rng, T, dim_in, 3)),
        ColVecs(randn(rng, T, dim_in, 3)),
        ColVecs(randn(rng, T, dim_in, 2));
        kwargs...,
    )
end

function test_interface(
    rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...,
) where {T<:Real}
    test_interface(
        k,
        RowVecs(randn(rng, T, 3, dim_in)),
        RowVecs(randn(rng, T, 3, dim_in)),
        RowVecs(randn(rng, T, 2, dim_in));
        kwargs...,
    )
end

function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
    test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
end

function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
    test_interface(rng, k, Vector{T}; kwargs...)
    test_interface(rng, k, ColVecs{T}; kwargs...)
    test_interface(rng, k, RowVecs{T}; kwargs...)
end

function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
    test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
end

end # module
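For context, here is a minimal sketch (not part of the diff) of how a downstream test suite might call these utilities. The kernel choice `SqExponentialKernel` and the module path `KernelFunctions.TestUtils` are assumptions for illustration; the diff only defines the module, not where it is included.

```julia
using KernelFunctions
using Random
using Test

# Assumes the module ends up reachable as `KernelFunctions.TestUtils`;
# the diff above only shows the module definition, not its include location.
using KernelFunctions.TestUtils: test_interface

@testset "SqExponentialKernel interface checks" begin
    rng = MersenneTwister(123456)

    # Generated test data for Vector, ColVecs, and RowVecs inputs.
    test_interface(rng, SqExponentialKernel(), Float64)

    # Or supply the three input vectors manually for a single input type:
    # x0 and x1 share a length, x2 has a different one.
    test_interface(SqExponentialKernel(), randn(rng, 5), randn(rng, 5), randn(rng, 3))
end
```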
Review thread on the choice of the default `α`:

> Is it common to use `α = 2`? Something like `α = 1` seems simpler, and is actually used as the default value by e.g. scikit-learn.
>
> Hmmm, good question. I really hadn't thought too much about this and wasn't overly fussed. I worry that 1 will yield a kernel with very long-range correlations (I'm just thinking about what a Student's t with dof 1 looks like -- I think they coincide in this case).
> I can't say that I'm overly fussed overall, so happy to change to 1 from 2 if you would prefer.
>
> I don't have a strong opinion on this, so do whatever you think is more reasonable. Just noticed this when I compared it with the parameter choices in scikit-learn, and was wondering why we use 2 here.
>
> I think I'm going to leave it as-is for the sake of this PR. More than happy for this to be changed in a subsequent PR though.