Skip to content

Commit 0654f2f

Browse files
sharanrywilltebbuttdevmotion
authored
Add Cosine Kernel (#45)
* Add Cosine Kernel * Update src/kernels/cosine.jl Co-Authored-By: willtebbutt <[email protected]> * Update src/kernels/cosine.jl Co-Authored-By: David Widmann <[email protected]> * Modify Cosine to be parameter-less * Add tests * Update Docstring * Change metric to Euclidean * Use cospi, change scaling factor to pi and update tests. Co-authored-by: willtebbutt <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 2133604 commit 0654f2f

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ include("distances/dotproduct.jl")
4444
include("distances/delta.jl")
4545
include("transform/transform.jl")
4646

47-
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated"]
47+
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine"]
4848
include(joinpath("kernels",k*".jl"))
4949
end
5050
include("kernels/transformedkernel.jl")
@@ -53,7 +53,6 @@ include("matrix/kernelmatrix.jl")
5353
include("kernels/kernelsum.jl")
5454
include("kernels/kernelproduct.jl")
5555
include("approximations/nystrom.jl")
56-
5756
include("generic.jl")
5857

5958
include("zygote_adjoints.jl")

src/kernels/cosine.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
CosineKernel
3+
4+
The cosine kernel is a stationary kernel for a sinusoidal given by
5+
```
6+
κ(x,y) = cos( π * (x-y) )
7+
```
8+
9+
"""
10+
struct CosineKernel <: BaseKernel end
11+
12+
kappa::CosineKernel, d::Real) = cospi(d)
13+
14+
metric(::CosineKernel) = Euclidean()

test/test_kernels.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
2626
@test kappa(k,0.5) == c
2727
end
2828
end
29+
@testset "Cosine" begin
30+
k = CosineKernel()
31+
@test eltype(k) == Any
32+
@test kappa(k, 1.0) -1.0 atol=1e-5
33+
@test kappa(k, 2.0) 1.0 atol=1e-5
34+
@test kappa(k, 1.5) 0.0 atol=1e-5
35+
@test kappa(k,x) cospi(x) atol=1e-5
36+
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
37+
end
2938
@testset "Exponential" begin
3039
@testset "SqExponentialKernel" begin
3140
k = SqExponentialKernel()
@@ -133,8 +142,8 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
133142
ks = ScaledKernel(k,s)
134143
@test kappa(kt,v1,v2) == kappa(transform(k,ScaleTransform(s)),v1,v2)
135144
@test kappa(kt,v1,v2) == kappa(transform(k,s),v1,v2)
136-
@test kappa(kt,v1,v2) == kappa(k,s*v1,s*v2)
137-
@test kappa(ktard,v1,v2) kappa(transform(k,ARDTransform(v)),v1,v2)
145+
@test kappa(kt,v1,v2) kappa(k,s*v1,s*v2) atol=1e-5
146+
@test kappa(ktard,v1,v2) kappa(transform(k,ARDTransform(v)),v1,v2) atol=1e-5
138147
@test kappa(ktard,v1,v2) == kappa(transform(k,v),v1,v2)
139148
@test kappa(ktard,v1,v2) == kappa(k,v.*v1,v.*v2)
140149
@test KernelFunctions.metric(kt) == KernelFunctions.metric(k)

0 commit comments

Comments
 (0)