Skip to content

Commit f1dad6e

Browse files
committed
Replaced sin(\pi) by sinpi and corrected the constructors and some tests
1 parent 53d831b commit f1dad6e

File tree

4 files changed

+40
-23
lines changed

4 files changed

+40
-23
lines changed

src/distances/sinus.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ Distances.parameters(d::Sinus) = d.r
88
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
99
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
1010
end
11-
return sum(abs2,sin.(π.*(a-b))./d.r)
11+
return sum(abs2, sinpi.(a - b) ./ d.r)
1212
end
1313

1414
# For later convenience once Distances.jl open their API
15-
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sin* (a - b)) / p)
15+
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
1616

1717
@inline (dist::Sinus)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
18-
@inline (dist::Sinus)(a::Number,b::Number) = abs2(sin* (a - b)) / first(dist.r))
18+
@inline (dist::Sinus)(a::Number,b::Number) = abs2(sinpi(a - b) / first(dist.r))

src/kernels/periodic.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77
"""
88
struct PeriodicKernel{T} <: BaseKernel
99
r::Vector{T}
10+
function PeriodicKernel(; r::AbstractVector{T} = ones(Float64, 1)) where {T<:Real}
11+
@assert all(r .> 0)
12+
new{T}(r)
13+
end
1014
end
1115

12-
PeriodicKernel(dims::Int) = PeriodicKernel{Float64}(ones(Float64,dims))
16+
PeriodicKernel(dims::Int = 1) = PeriodicKernel(Float64, dims)
17+
18+
PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
1319

1420
metric::PeriodicKernel) = Sinus.r)
1521

test/test_flux.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,31 @@ using KernelFunctions
22
using Test
33
using Flux
44

5+
#Alphabetic order
56
@testset "Params" begin
6-
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
7+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; r = 2.0
8+
## Base kernels
79
kc = ConstantKernel(c=c)
810
@test all(params(kc) .== params([c]))
9-
km = MaternKernel=ν)
10-
@test all(params(km) .== params([ν]))
11-
kl = LinearKernel(c=c)
12-
@test all(params(kl) .== params([c]))
11+
1312
kge = GammaExponentialKernel=γ)
1413
@test all(params(kge) .== params([γ]))
14+
1515
kgr = GammaRationalQuadraticKernel=γ, α=α)
1616
@test all(params(kgr) .== params([α], [γ]))
17+
18+
km = MaternKernel=ν)
19+
@test all(params(km) .== params([ν]))
20+
21+
kl = LinearKernel(c=c)
22+
@test all(params(kl) .== params([c]))
23+
24+
kpe = PeriodicKernel(r=[r])
25+
@test all(params(kpe) .== params(r))
26+
1727
kp = PolynomialKernel(c=c, d=d)
1828
@test all(params(kp) .== params([d], [c]))
29+
1930
kr = RationalQuadraticKernel=α)
2031
@test all(params(kr) .== params([α]))
2132

test/test_kernels.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,28 +120,28 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
120120
end
121121
@testset "Periodic Kernel" begin
122122
r = rand(3)
123-
k = PeriodicKernel(r)
123+
k = PeriodicKernel(r = r)
124124
@test kappa(k, x) exp(-0.5x)
125-
@test k(v1, v2) exp(-0.5 * sum(abs2,sin.(π*(v1-v2))./r))
125+
@test k(v1, v2) exp(-0.5 * sum(abs2, sinpi.(v1 - v2) ./ r))
126126
@test k(v1, v2) == k(v2, v1)
127-
127+
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
128128
end
129129
@testset "Transformed/Scaled Kernel" begin
130130
s = rand()
131131
v = rand(3)
132132
k = SqExponentialKernel()
133-
kt = TransformedKernel(k,ScaleTransform(s))
134-
ktard = TransformedKernel(k,ARDTransform(v))
135-
ks = ScaledKernel(k,s)
136-
@test kappa(kt,v1,v2) == kappa(transform(k,ScaleTransform(s)),v1,v2)
137-
@test kappa(kt,v1,v2) == kappa(transform(k,s),v1,v2)
138-
@test kappa(kt,v1,v2) == kappa(k,s*v1,s*v2)
139-
@test kappa(ktard,v1,v2) == kappa(transform(k,ARDTransform(v)),v1,v2)
140-
@test kappa(ktard,v1,v2) == kappa(transform(k,v),v1,v2)
141-
@test kappa(ktard,v1,v2) == kappa(k,v.*v1,v.*v2)
133+
kt = TransformedKernel(k, ScaleTransform(s))
134+
ktard = TransformedKernel(k, ARDTransform(v))
135+
ks = ScaledKernel(k, s)
136+
@test kappa(kt, v1, v2) == kappa(transform(k, ScaleTransform(s)), v1, v2)
137+
@test kappa(kt, v1, v2) == kappa(transform(k, s), v1, v2)
138+
@test kappa(kt, v1, v2) kappa(k, s * v1, s * v2)
139+
@test kappa(ktard, v1, v2) == kappa(transform(k, ARDTransform(v)), v1, v2)
140+
@test kappa(ktard, v1, v2) == kappa(transform(k, v), v1, v2)
141+
@test kappa(ktard, v1, v2) == kappa(k, v .* v1, v .* v2)
142142
@test KernelFunctions.metric(kt) == KernelFunctions.metric(k)
143-
@test kappa(ks,x) == s*kappa(k,x)
144-
@test kappa(ks,x) == kappa(s*k,x)
143+
@test kappa(ks, x) == s * kappa(k, x)
144+
@test kappa(ks, x) == kappa(s * k, x)
145145
end
146146
@testset "KernelCombinations" begin
147147
k1 = LinearKernel()

0 commit comments

Comments
 (0)