Skip to content

Commit 9f708d0

Browse files
authored
Testing (and fixing) handling of AbstractVector{AbstractVector{T}} inputs (#370)
* Fix input types, improve readability * Add missing bit * Add doc string * Fix mistake * Add docstring to docs * Reformulate * Add test * Potential fix for Delta distance * Fix fbm * Further Delta fixes * Formatter * Remove old Delta implementation * Make interface more specific * Formatter * Bump patch
1 parent 979a019 commit 9f708d0

File tree

4 files changed

+23
-15
lines changed

4 files changed

+23
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.22"
3+
version = "0.10.23"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ end
4343
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2
4444

4545
_mod(x::AbstractVector{<:Real}) = abs2.(x)
46+
_mod(x::AbstractVector{<:AbstractVector{<:Real}}) = sum.(abs2, x)
47+
# two lines above could be combined into the second (dispatching on general AbstractVectors), but this (somewhat) more performant
4648
_mod(x::ColVecs) = vec(sum(abs2, x.X; dims=1))
4749
_mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))
4850

src/distances/delta.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
11
# Delta is not following the PreMetric rules since d(x, x) == 1
22
struct Delta <: Distances.UnionPreMetric end
33

4-
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector)
5-
@boundscheck if length(a) != length(b)
6-
throw(
7-
DimensionMismatch(
8-
"first array has length $(length(a)) which does not match the length of the " *
9-
"second, $(length(b)).",
10-
),
11-
)
12-
end
13-
return a == b
14-
end
15-
16-
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
17-
4+
@inline Distances.eval_op(::Delta, a::Real, b::Real) = a == b
5+
@inline Distances.eval_reduce(::Delta, a, b) = a && b
6+
@inline Distances.eval_start(::Delta, a, b) = true
187
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
198
@inline (dist::Delta)(a::Number, b::Number) = a == b
9+
10+
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool

src/test_utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ function test_interface(
133133
)
134134
end
135135

136+
function test_interface(
137+
rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
138+
) where {T<:Real}
139+
return test_interface(
140+
k,
141+
[randn(rng, T, dim_in) for _ in 1:1001],
142+
[randn(rng, T, dim_in) for _ in 1:1001],
143+
[randn(rng, T, dim_in) for _ in 1:1000];
144+
kwargs...,
145+
)
146+
end
147+
136148
function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
137149
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
138150
end
@@ -147,6 +159,9 @@ function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
147159
@testset "RowVecs{$T}" begin
148160
test_interface(rng, k, RowVecs{T}; kwargs...)
149161
end
162+
@testset "Vector{Vector{T}}" begin
163+
test_interface(rng, k, Vector{Vector{T}}; kwargs...)
164+
end
150165
end
151166

152167
function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)

0 commit comments

Comments
 (0)