Skip to content

Commit f24f285

Browse files
willtebbuttwtmolet
authored
Implement periodic transform (#173)
* Implement periodic transform * Improve docstring * Flip sin-cos order * Use vectors rather than Refs * Update src/transform/periodic_transform.jl Co-authored-by: Letif Mones <[email protected]> * Bump patch Co-authored-by: wt <[email protected]> Co-authored-by: Letif Mones <[email protected]>
1 parent 744419e commit f24f285

File tree

6 files changed

+61
-13
lines changed

6 files changed

+61
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.0"
3+
version = "0.8.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/KernelFunctions.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ export TransformedKernel, ScaledKernel
3838
export TensorProduct
3939

4040
export Transform, SelectTransform, ChainTransform, ScaleTransform, LinearTransform,
41-
ARDTransform, IdentityTransform, FunctionTransform
41+
ARDTransform, IdentityTransform, FunctionTransform, PeriodicTransform
4242

4343
export NystromFact, nystrom
4444

@@ -57,10 +57,6 @@ using StatsFuns: logtwo
5757
using InteractiveUtils: subtypes
5858
using StatsBase
5959

60-
"""
61-
Abstract type defining a slice-wise transformation on an input matrix
62-
"""
63-
abstract type Transform end
6460

6561
abstract type Kernel end
6662
abstract type SimpleKernel <: Kernel end
@@ -70,7 +66,15 @@ include(joinpath("distances", "pairwise.jl"))
7066
include(joinpath("distances", "dotproduct.jl"))
7167
include(joinpath("distances", "delta.jl"))
7268
include(joinpath("distances", "sinus.jl"))
69+
7370
include(joinpath("transform", "transform.jl"))
71+
include(joinpath("transform", "scaletransform.jl"))
72+
include(joinpath("transform", "ardtransform.jl"))
73+
include(joinpath("transform", "lineartransform.jl"))
74+
include(joinpath("transform", "functiontransform.jl"))
75+
include(joinpath("transform", "selecttransform.jl"))
76+
include(joinpath("transform", "chaintransform.jl"))
77+
include(joinpath("transform", "periodic_transform.jl"))
7478

7579
include(joinpath("basekernels", "constant.jl"))
7680
include(joinpath("basekernels", "cosine.jl"))

src/transform/periodic_transform.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
PeriodicTransform(f)
3+
4+
Makes a kernel periodic by mapping a scalar input onto the unit circle. Samples from a GP
5+
with a kernel with this transformation applied will produce samples with frequency `f`.
6+
"""
7+
struct PeriodicTransform{Tf<:AbstractVector{<:Real}} <: Transform
8+
f::Tf
9+
end
10+
11+
@functor PeriodicTransform
12+
13+
PeriodicTransform(f::Real) = PeriodicTransform([f])
14+
15+
dim(t::PeriodicTransform) = 2
16+
17+
(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)]
18+
19+
function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
20+
return RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x)))
21+
end
22+
23+
function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform)
24+
return isequal(first(t1.f), first(t2.f))
25+
end
26+
27+
function Base.show(io::IO, t::PeriodicTransform)
28+
print(io, "Periodic Transform with frequency $(first(t.f))")
29+
end

src/transform/transform.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
include("scaletransform.jl")
2-
include("ardtransform.jl")
3-
include("lineartransform.jl")
4-
include("functiontransform.jl")
5-
include("selecttransform.jl")
6-
include("chaintransform.jl")
7-
1+
"""
2+
Abstract type defining a slice-wise transformation on an input matrix
3+
"""
4+
abstract type Transform end
85

96
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
107
_map(t::Transform, x::AbstractVector) = t.(x)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ include("test_utils.jl")
7272
print(" ")
7373
include(joinpath("transform", "chaintransform.jl"))
7474
print(" ")
75+
include(joinpath("transform", "periodic_transform.jl"))
76+
print(" ")
7577
end
7678
@info "Ran tests on Transform"
7779

test/transform/periodic_transform.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@testset "periodic_transform" begin
2+
@testset "compare to periodic exponentiated quadratic" begin
3+
rng = MersenneTwister(123456)
4+
f = rand(rng) + 2.0
5+
x = collect(range(0.0, 3.0 / f; length=1_000))
6+
7+
# Construct in the usual way.
8+
k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f)
9+
10+
# Construct using the peridic transform.
11+
k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f))
12+
13+
@test kernelmatrix(k_eq_periodic, x) kernelmatrix(k_eq_transform, x)
14+
# TODO - add interface_tests once #159 is merged.
15+
end
16+
end

0 commit comments

Comments
 (0)