Skip to content

Commit 8038d1f

Browse files
committed
Adapted PR to the new pkg structure
1 parent 30586ea commit 8038d1f

File tree

9 files changed

+35
-15
lines changed

9 files changed

+35
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515

1616
[compat]
1717
Compat = "2.2, 3"
18-
Distances = "0.8"
18+
Distances = "0.8.2"
1919
Requires = "1.0.1"
2020
SpecialFunctions = "0.8, 0.9, 0.10"
2121
StatsBase = "0.32, 0.33"

src/distances/delta.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ end
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
8-
return a==b
8+
return a == b
99
end
1010

11-
@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
12-
@inline (dist::Delta)(a::Number,b::Number) = a==b
11+
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
12+
@inline (dist::Delta)(a::Number,b::Number) = a == b

src/distances/dotproduct.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
struct DotProduct <: Distances.PreMetric end
2+
# struct DotProduct <: Distances.UnionSemiMetric end
23

34
@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
45
@boundscheck if length(a) != length(b)
@@ -7,8 +8,6 @@ struct DotProduct <: Distances.PreMetric end
78
return dot(a,b)
89
end
910

10-
# For later convenience once Distances.jl open their API
1111
@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
12-
1312
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
14-
@inline (dist::DotProduct)(a::Number,b::Number) = Distances.eval_op(dist, a, b)
13+
@inline (dist::DotProduct)(a::Number,b::Number) = a * b

src/distances/sinus.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
struct Sinus{T} <: Distances.SemiMetric
2+
# struct Sinus{T} <: Distances.UnionSemiMetric
23
r::Vector{T}
34
end
45

56
Distances.parameters(d::Sinus) = d.r
7+
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
8+
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
9+
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
610

711
@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
812
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
913
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
1014
end
1115
return sum(abs2, sinpi.(a - b) ./ d.r)
1216
end
13-
14-
# For later convenience once Distances.jl open their API
15-
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
16-
17-
@inline (dist::Sinus)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
18-
@inline (dist::Sinus)(a::Number,b::Number) = abs2(sinpi(a - b) / first(dist.r))

src/kernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct PeriodicKernel{T} <: BaseKernel
1313
end
1414
end
1515

16-
PeriodicKernel(dims::Int = 1) = PeriodicKernel(Float64, dims)
16+
PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)
1717

1818
PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
1919

test/distances/sinus.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testset "sinus" begin
2+
A = rand(10)
3+
B = rand(10)
4+
p = rand(10)
5+
d = KernelFunctions.Sinus(p)
6+
@test Distances.parameters(d) == p
7+
@test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p))
8+
@test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
9+
end

test/kernels/periodic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testset "Periodic Kernel" begin
2+
x = rand()*2; v1 = rand(3); v2 = rand(3);
3+
r = rand(3)
4+
k = PeriodicKernel(r = r)
5+
@test kappa(k, x) exp(-0.5x)
6+
@test k(v1, v2) exp(-0.5 * sum(abs2, sinpi.(v1 - v2) ./ r))
7+
@test k(v1, v2) == k(v2, v1)
8+
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
9+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ using KernelFunctions: metric
4949
@testset "distances" begin
5050
include(joinpath("distances", "dotproduct.jl"))
5151
include(joinpath("distances", "delta.jl"))
52+
include(joinpath("distances", "sinus.jl"))
5253
end
5354

5455
@testset "transform" begin
@@ -71,6 +72,7 @@ using KernelFunctions: metric
7172
include(joinpath("kernels", "kernelproduct.jl"))
7273
include(joinpath("kernels", "kernelsum.jl"))
7374
include(joinpath("kernels", "matern.jl"))
75+
include(joinpath("kernels", "periodic.jl"))
7476
include(joinpath("kernels", "polynomial.jl"))
7577
include(joinpath("kernels", "piecewisepolynomial.jl"))
7678
include(joinpath("kernels", "rationalquad.jl"))

test/trainable.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "trainable" begin
2-
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5
2+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5; r = rand(3)
33

44
kc = ConstantKernel(c=c)
55
@test all(params(kc) .== params([c]))
@@ -22,6 +22,9 @@
2222
kp = PolynomialKernel(c=c, d=d)
2323
@test all(params(kp) .== params([d], [c]))
2424

25+
kpe = PeriodicKernel(r = r)
26+
@test all(params(kpe) .== params(r))
27+
2528
kr = RationalQuadraticKernel=α)
2629
@test all(params(kr) .== params([α]))
2730

0 commit comments

Comments
 (0)