Skip to content

Commit 46fa287

Browse files
willtebbuttgithub-actions[bot]devmotion
authored
Ensure accepted input types are general where possible (#480)
* Strings with FunctionTransform * String with SelectTransform * Strings with ConstantKernel * Add Vector{String} test util * Constant kernelswith strings * Test kernelproduct with strings * Test kernelsum with strings * Transformed kernel with string input * KernelTensorProduct with String * Bump patch * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Formatting * Formatting suggestions Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent d425983 commit 46fa287

File tree

11 files changed

+55
-10
lines changed

11 files changed

+55
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.46"
3+
version = "0.10.47"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/TestUtils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,28 @@ function test_interface(
142142
)
143143
end
144144

145+
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...)
146+
return test_interface(
147+
k,
148+
[randstring(rng) for _ in 1:3],
149+
[randstring(rng) for _ in 1:3],
150+
[randstring(rng) for _ in 1:4];
151+
kwargs...,
152+
)
153+
end
154+
155+
function test_interface(
156+
rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
157+
)
158+
return test_interface(
159+
k,
160+
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
161+
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]),
162+
ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:4]);
163+
kwargs...,
164+
)
165+
end
166+
145167
function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
146168
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
147169
end

src/distances/delta.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# Delta is not following the PreMetric rules since d(x, x) == 1
22
struct Delta <: Distances.UnionPreMetric end
33

4-
(dist::Delta)(a::Number, b::Number) = a == b
5-
Base.@propagate_inbounds function (dist::Delta)(
6-
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
7-
)
4+
(dist::Delta)(a, b) = a == b
5+
6+
Base.@propagate_inbounds function (dist::Delta)(a::AbstractArray, b::AbstractArray)
87
@boundscheck if length(a) != length(b)
98
throw(
109
DimensionMismatch(

src/transform/selecttransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ _maybe_unwrap(x::AbstractArray{<:Any,0}) = x[]
2828
_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
2929
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
3030

31-
_wrap(x::AbstractVector{<:Real}, ::Any) = x
32-
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)
31+
_wrap(x::AbstractVector, ::Any) = x
32+
_wrap(X::AbstractMatrix, ::Type{T}) where {T} = T(X)
3333

3434
Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")

test/basekernels/constant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# Standardised tests.
1010
TestUtils.test_interface(k, Float64)
11+
TestUtils.test_interface(k, Vector{String})
1112
test_ADs(ZeroKernel)
1213
test_interface_ad_perf(_ -> k, nothing, StableRNG(123456))
1314
end
@@ -22,6 +23,7 @@
2223

2324
# Standardised tests.
2425
TestUtils.test_interface(k, Float64)
26+
TestUtils.test_interface(k, Vector{String})
2527
test_ADs(WhiteKernel)
2628
test_interface_ad_perf(_ -> k, nothing, StableRNG(123456))
2729
end
@@ -38,6 +40,7 @@
3840

3941
# Standardised tests.
4042
TestUtils.test_interface(k, Float64)
43+
TestUtils.test_interface(k, Vector{String})
4144
test_ADs(c -> ConstantKernel(; c=only(c)), [c])
4245
test_interface_ad_perf(c -> ConstantKernel(; c=c), c, StableRNG(123456))
4346
end

test/kernels/kernelproduct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
# Standardised tests.
1313
TestUtils.test_interface(k, Float64)
14+
TestUtils.test_interface(ConstantKernel(; c=1.0) * WhiteKernel(), Vector{String})
1415
test_ADs(
1516
x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1)
1617
)

test/kernels/kernelsum.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
# Standardised tests.
1414
TestUtils.test_interface(k, Float64)
15+
TestUtils.test_interface(ConstantKernel(; c=1.5) * WhiteKernel(), Vector{String})
1516
test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1))
1617
test_interface_ad_perf(2.4, StableRNG(123456)) do c
1718
KernelSum(SqExponentialKernel(), LinearKernel(; c=c))

test/kernels/kerneltensorproduct.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
# Standardised tests.
3333
TestUtils.test_interface(kernel1, ColVecs{Float64})
3434
TestUtils.test_interface(kernel1, RowVecs{Float64})
35+
TestUtils.test_interface(
36+
KernelTensorProduct(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
37+
)
3538
test_ADs(
3639
x -> KernelTensorProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
3740
rand(1);

test/kernels/transformedkernel.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))
2626

2727
TestUtils.test_interface(k, Float64)
28+
TestUtils.test_interface(
29+
TransformedKernel(ConstantKernel(; c=1.5), FunctionTransform(x -> x * "hi")),
30+
Vector{String},
31+
)
2832
test_ADs(x -> SqExponentialKernel() ScaleTransform(x[1]), rand(1))
2933
test_interface_ad_perf(0.35, StableRNG(123456)) do λ
3034
SqExponentialKernel() ScaleTransform(λ)

test/transform/functiontransform.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
end
2727
end
2828

29+
@testset "String input" begin
30+
f = x -> x * "hello"
31+
t = FunctionTransform(f)
32+
x = [randstring(rng) for _ in 1:3]
33+
y = map(t, x)
34+
@test all([t(x[n]) == y[n] for n in eachindex(x)])
35+
@test all([f(x[n]) == y[n] for n in eachindex(x)])
36+
end
37+
2938
@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
3039
f(a, x) = sin.(a .* x)
3140
test_ADs(x -> SEKernel() FunctionTransform(y -> f(x, y)), randn(rng, 3))

test/transform/selecttransform.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
x_vecs = [randn(rng, maximum(select)) for _ in 1:3]
88
x_cols = ColVecs(randn(rng, maximum(select), 6))
99
x_rows = RowVecs(randn(rng, 4, maximum(select)))
10+
x_string = ColVecs([randstring(rng) for _ in 1:maximum(select), _ in 1:5])
1011

11-
Xs = [x_vecs, x_cols, x_rows]
12+
Xs = [x_vecs, x_cols, x_rows, x_string]
1213

1314
@testset "$(typeof(x))" for x in Xs
1415
x′ = map(t, x)
@@ -24,8 +25,9 @@
2425
a_vecs = map(x -> AxisArray(x; col=symbols), x_vecs)
2526
a_cols = ColVecs(AxisArray(x_cols.X; col=symbols, index=(1:6)))
2627
a_rows = RowVecs(AxisArray(x_rows.X; index=(1:4), col=symbols))
28+
a_string = ColVecs(AxisArray(x_string.X; col=symbols, index=1:5))
2729

28-
As = [a_vecs, a_cols, a_rows]
30+
As = [a_vecs, a_cols, a_rows, a_string]
2931

3032
@testset "$(typeof(a))" for (a, x) in zip(As, Xs)
3133
a′ = map(ts, a)
@@ -122,8 +124,9 @@
122124
("Vector{<:Vector}", [randn(6) for _ in 1:3]),
123125
("ColVecs", ColVecs(randn(5, 10))),
124126
("RowVecs", RowVecs(randn(11, 4))),
127+
("ColVecs{String}", ColVecs([randstring() for _ in 1:6, _ in 1:5])),
125128
]
126-
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
129+
@test KernelFunctions._map(t, x) isa AbstractVector
127130
end
128131
end
129132
end

0 commit comments

Comments
 (0)