Skip to content

Commit 89baf9c

Browse files
authored
Merge pull request JuliaGaussianProcesses#85 from theogf/format_and_tests
Improved docs, printing and testing
2 parents bb3e859 + 93f11c2 commit 89baf9c

37 files changed

+129
-56
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ export duplicate, set! # Helpers
1010

1111
export Kernel
1212
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel
13-
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
13+
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
14+
export LaplacianKernel, ExponentialKernel, GammaExponentialKernel
1415
export ExponentiatedKernel
1516
export MaternKernel, Matern32Kernel, Matern52Kernel
1617
export LinearKernel, PolynomialKernel

src/basekernels/constant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,4 @@ kappa(κ::ConstantKernel,x::Real) = first(κ.c)*one(x)
5959

6060
metric(::ConstantKernel) = Delta()
6161

62-
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = $(first.c)))")
62+
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first.c), ")")

src/basekernels/exponential.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@ kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^first(κ.γ))
6363
iskroncompatible(::GammaExponentialKernel) = true
6464
metric(::GammaExponentialKernel) = SqEuclidean()
6565

66-
Base.show(io::IO, κ::GammaExponentialKernel) = print(io, "Gamma Exponential Kernel (γ = $(first.γ)))")
66+
Base.show(io::IO, κ::GammaExponentialKernel) = print(io, "Gamma Exponential Kernel (γ = ", first.γ), ")")

src/basekernels/fbm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct FBMKernel{T<:Real} <: BaseKernel
1717
end
1818
end
1919

20-
Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = $(first(k.h)))")
20+
Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = ", first(κ.h), ")")
2121

2222
const sqroundoff = 1e-15
2323

src/basekernels/gabor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function _gabor(; ell = nothing, p = nothing)
3030
end
3131

3232
function Base.getproperty(k::GaborKernel, v::Symbol)
33-
if v == :kernel
33+
if v == :kernel
3434
return getfield(k, v)
3535
elseif v == :ell
3636
kernel1 = k.kernel.kernels[1]
@@ -51,7 +51,7 @@ function Base.getproperty(k::GaborKernel, v::Symbol)
5151
end
5252
end
5353

54-
Base.show(io::IO, κ::GaborKernel) = print(io, "Gabor Kernel (ell = $(κ.ell), p = $(κ.p))")
54+
Base.show(io::IO, κ::GaborKernel) = print(io, "Gabor Kernel (ell = ", κ.ell, ", p = ", κ.p, ")")
5555

5656
kappa::GaborKernel, x, y) = kappa.kernel, x ,y)
5757

src/basekernels/maha.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ end
1919
kappa::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)
2020
metric::MahalanobisKernel) = SqMahalanobis.P)
2121

22-
Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = $(size.P))")
22+
Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size.P), ")")

src/basekernels/matern.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727

2828
metric(::MaternKernel) = Euclidean()
2929

30-
Base.show(io::IO, κ::MaternKernel) = print(io, "Matern Kernel (ν = $(first.ν)))")
30+
Base.show(io::IO, κ::MaternKernel) = print(io, "Matern Kernel (ν = ", first.ν), ")")
3131

3232
"""
3333
Matern32Kernel()

src/basekernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ metric(κ::PeriodicKernel) = Sinus(κ.r)
2424

2525
kappa::PeriodicKernel, d::Real) = exp(- 0.5d)
2626

27-
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = $(length.r))")
27+
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel, length(r) = ", length.r), ")")

src/basekernels/piecewisepolynomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,5 +105,5 @@ end
105105
metric::PiecewisePolynomialKernel) = Mahalanobis.maha)
106106

107107
function Base.show(io::IO, κ::PiecewisePolynomialKernel{V}) where {V}
108-
print(io, "Piecewise Polynomial Kernel (v = $(V), size(maha) = $(size.maha))")
108+
print(io, "Piecewise Polynomial Kernel (v = ", V, ", size(maha) = ", size.maha), ")")
109109
end

src/basekernels/polynomial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
kappa::LinearKernel, xᵀy::Real) = xᵀy + first.c)
1818
metric(::LinearKernel) = DotProduct()
1919

20-
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = $(first.c)))")
20+
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first.c), ")")
2121

2222
"""
2323
PolynomialKernel(; d = 2.0, c = 0.0)
@@ -40,4 +40,4 @@ end
4040
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first.c))^(first.d))
4141
metric(::PolynomialKernel) = DotProduct()
4242

43-
Base.show(io::IO, κ::PolynomialKernel) = print(io, "Polynomial Kernel (c = $(first.c)), d = $(first.d)))")
43+
Base.show(io::IO, κ::PolynomialKernel) = print(io, "Polynomial Kernel (c = ", first.c), ", d = ", first.d), ")")

src/basekernels/rationalquad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/first.α))^(-first.α))
1919
metric(::RationalQuadraticKernel) = SqEuclidean()
2020

21-
Base.show(io::IO, κ::RationalQuadraticKernel) = print(io, "Rational Quadratic Kernel (α = $(first.α)))")
21+
Base.show(io::IO, κ::RationalQuadraticKernel) = print(io, "Rational Quadratic Kernel (α = ", first.α), ")")
2222

2323
"""
2424
`GammaRationalQuadraticKernel([ρ=1.0[,α=2.0[,γ=2.0]]])`
@@ -41,4 +41,4 @@ end
4141
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^first.γ)/first.α))^(-first.α))
4242
metric(::GammaRationalQuadraticKernel) = SqEuclidean()
4343

44-
Base.show(io::IO, κ::GammaRationalQuadraticKernel) = print(io, "Gamma Rational Quadratic Kernel (α = $(first.α)), γ = $(first.γ)))")
44+
Base.show(io::IO, κ::GammaRationalQuadraticKernel) = print(io, "Gamma Rational Quadratic Kernel (α = ", first.α), ", γ = ", first.γ), ")")

src/generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ _scale(t::ScaleTransform, metric::Euclidean, x, y) = first(t.s) * evaluate(metr
1111
_scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) = first(t.s)^2 * evaluate(metric, x, y)
1212
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t, y))
1313

14-
printshifted(io::IO,κ::Kernel,shift::Int) = print(io,"")
15-
Base.show(io::IO::Kernel) = print(io,nameof(typeof(κ)))
14+
printshifted(io::IO, o, shift::Int) = print(io, o)
15+
Base.show(io::IO, κ::Kernel) = print(io, nameof(typeof(κ)))
1616

1717
### Syntactic sugar for creating matrices and using kernel functions
1818
function concretetypes(k, ktypes::Vector)

src/transform/ardtransform.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,52 @@
11
"""
2-
ARD Transform
2+
ARDTransform(v::AbstractVector)
3+
ARDTransform(s::Real, dims::Int)
4+
5+
Multiply every vector of observation by `v` element-wise
36
```
47
v = rand(3)
58
tr = ARDTransform(v)
69
```
7-
Multiply every vector of observation by `v` element-wise
810
"""
911
struct ARDTransform{T,N} <: Transform
1012
v::Vector{T}
1113
end
1214

13-
function ARDTransform(s::T,dims::Integer) where {T<:Real}
15+
function ARDTransform(s::T, dims::Integer) where {T<:Real}
1416
@check_args(ARDTransform, s, s > zero(T), "s > 0")
15-
ARDTransform{T,dims}(fill(s,dims))
17+
ARDTransform{T,dims}(fill(s, dims))
1618
end
1719

1820
function ARDTransform(v::AbstractVector{T}) where {T<:Real}
19-
@check_args(ARDTransform, v, all(v.>zero(T)), "v > 0")
21+
@check_args(ARDTransform, v, all(v .> zero(T)), "v > 0")
2022
ARDTransform{T,length(v)}(v)
2123
end
2224

23-
function set!(t::ARDTransform{T}::AbstractVector{T}) where {T<:Real}
25+
function set!(t::ARDTransform{T}, ρ::AbstractVector{T}) where {T<:Real}
2426
@assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ARDTransform of dimension $(dim(t))"
2527
t.v .= ρ
2628
end
2729

2830
dim(t::ARDTransform) = length(t.v)
2931

30-
function apply(t::ARDTransform,X::AbstractMatrix{<:Real};obsdim::Int = defaultobs)
31-
@boundscheck if dim(t) != size(X,feature_dim(obsdim))
32+
function apply(
33+
t::ARDTransform,
34+
X::AbstractMatrix{<:Real};
35+
obsdim::Int = defaultobs,
36+
)
37+
@boundscheck if dim(t) != size(X, feature_dim(obsdim))
3238
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t)).")) #TODO Add test
3339
end
34-
_transform(t,X,obsdim)
40+
_transform(t, X, obsdim)
3541
end
36-
apply(t::ARDTransform,x::AbstractVector{<:Real};obsdim::Int=defaultobs) = t.v .* x
37-
_transform(t::ARDTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.v'.*X : t.v .* X
42+
apply(t::ARDTransform, x::AbstractVector{<:Real}; obsdim::Int = defaultobs) = t.v .* x
43+
_transform(
44+
t::ARDTransform,
45+
X::AbstractMatrix{<:Real},
46+
obsdim::Int = defaultobs,
47+
) = obsdim == 1 ? t.v' .* X : t.v .* X
48+
49+
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
3850

39-
Base.isequal(t::ARDTransform,t2::ARDTransform) = isequal(t.v,t2.v)
51+
Base.show(io::IO, t::ARDTransform) =
52+
print(io, "ARD Transform (dims: ", dim(t),")")

src/transform/chaintransform.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
"""
2+
ChainTransform(ts::AbstractVector{<:Transform})
3+
24
Chain a series of transform, here `t1` will be called first
35
```
46
t1 = ScaleTransform()
57
t2 = LowRankTransform(rand(3,4))
68
ct = ChainTransform([t1,t2]) #t1 will be called first
7-
ct == t2t1
9+
ct == t2t1
810
```
911
"""
10-
struct ChainTransform <: Transform
11-
transforms::Vector{Transform}
12+
struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform
13+
transforms::V
1214
end
1315

1416
Base.length(t::ChainTransform) = length(t.transforms) #TODO Add test
1517

16-
function ChainTransform(v::AbstractVector{<:Transform})
17-
ChainTransform(v)
18-
end
19-
2018
## Constructor to create a chain transform with an array of parameters
2119
function ChainTransform(v::AbstractVector{<:Type{<:Transform}}::AbstractVector)
2220
@assert length(v) == length(θ)
@@ -34,6 +32,23 @@ end
3432
set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
3533
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
3634

37-
Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
38-
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
39-
Base.:(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
35+
Base.:(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁])
36+
Base.:(t::Transform, tc::ChainTransform) =
37+
ChainTransform(vcat(tc.transforms, t)) #TODO add test
38+
Base.:(tc::ChainTransform, t::Transform) =
39+
ChainTransform(vcat(t, tc.transforms))
40+
41+
Base.show(io::IO, t::ChainTransform) = printshifted(io, t, 0)
42+
43+
function printshifted(io::IO, t::ChainTransform, shift::Int)
44+
println(io, "Chain of ", length(t), " transforms:")
45+
for _ in 1:(shift + 1)
46+
print(io, "\t")
47+
end
48+
print(io, " - ")
49+
printshifted(io, t.transforms[1], shift + 2)
50+
for i in 2:length(t)
51+
print(io, " |> ")
52+
printshifted(io, t.transforms[i], shift + 2)
53+
end
54+
end

src/transform/functiontransform.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
2-
FunctionTransform
2+
FunctionTransform(f)
3+
4+
Take a function or object `f` as an argument which is going to act on each vector individually.
5+
Make sure that `f` is supposed to act on a vector.
6+
For example replace `f(x)=sin(x)` by `f(x)=sin.(x)`
37
```
48
f(x) = abs.(x)
59
tr = FunctionTransform(f)
610
```
7-
Take a function or object `f` as an argument which is going to act on each vector individually.
8-
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
9-
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
1011
"""
1112
struct FunctionTransform{F} <: Transform
1213
f::F
@@ -15,3 +16,5 @@ end
1516
apply(t::FunctionTransform, X::T; obsdim::Int = defaultobs) where {T} = mapslices(t.f, X, dims = feature_dim(obsdim))
1617

1718
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
19+
20+
Base.show(io::IO, t::FunctionTransform) = print(io, "Function Transform: ", t.f)

src/transform/lowranktransform.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""
2-
LowRankTransform
2+
LowRankTransform(P::AbstractMatrix)
3+
4+
Apply the low-rank projection realised by the matrix `P`
5+
The second dimension of `P` must match the number of features of the target.
36
```
47
P = rand(10,5)
58
tr = LowRankTransform(P)
69
```
7-
Apply the low-rank projection realised by the matrix `P`
8-
The second dimension of `P` must match the number of features of the target.
910
"""
1011
struct LowRankTransform{T<:AbstractMatrix{<:Real}} <: Transform
1112
proj::T
@@ -32,3 +33,5 @@ function apply(t::LowRankTransform, x::AbstractVector{<:Real}; obsdim::Int = def
3233
end
3334

3435
_transform(t::LowRankTransform,X::AbstractVecOrMat{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? t.proj * X : X * t.proj'
36+
37+
Base.show(io::IO, t::LowRankTransform) = print(io::IO, "Low Rank Transform (size(P) = ", size(t.proj), ")")

src/transform/scaletransform.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""
2-
Scale Transform
2+
ScaleTransform(l::Real)
3+
4+
Multiply every element of the input by `l`
35
```
46
l = 2.0
57
tr = ScaleTransform(l)
68
```
7-
Multiply every element of the input by `l`
89
"""
910
struct ScaleTransform{T<:Real} <: Transform
1011
s::Vector{T}
@@ -22,4 +23,4 @@ apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s)
2223

2324
Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))
2425

25-
Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform s=$(first(t.s))")
26+
Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform (s = ", first(t.s), ")")

src/transform/selecttransform.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
2-
SelectTransform
2+
SelectTransform(dims::AbstractVector{Int})
3+
4+
Select the dimensions `dims` that the kernel is applied to.
35
```
46
dims = [1,3,5,6,7]
57
tr = SelectTransform(dims)
68
X = rand(100,10)
79
transform(tr,X,obsdim=2) == X[dims,:]
810
```
9-
Select the dimensions `dims` that the kernel is applied to.
1011
"""
1112
struct SelectTransform{T<:AbstractVector{<:Int}} <: Transform
1213
select::T
@@ -38,3 +39,5 @@ function apply(t::SelectTransform, x::AbstractVector{<:Real}; obsdim::Int = defa
3839
end
3940

4041
_transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? view(X,t.select,:) : view(X,:,t.select)
42+
43+
Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")

src/transform/transform.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ include("selecttransform.jl")
88
include("chaintransform.jl")
99

1010
"""
11-
`apply(t::Transform, x; obsdim::Int=defaultobs)`
12-
Apply the transform `t` per slice on the array `x`
11+
apply(t::Transform, x; obsdim::Int=defaultobs)
12+
13+
Apply the transform `t` vector-wise on the array `x`
1314
"""
1415
apply
1516

1617
"""
17-
IdentityTransform
18+
IdentityTransform()
19+
1820
Return exactly the input
1921
"""
2022
struct IdentityTransform <: Transform end

test/basekernels/constant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
@test eltype(k) == Any
55
@test kappa(k,2.0) == 0.0
66
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
7+
@test repr(k) == "Zero Kernel"
78
end
89
@testset "WhiteKernel" begin
910
k = WhiteKernel()
@@ -12,6 +13,7 @@
1213
@test kappa(k,0.0) == 0.0
1314
@test EyeKernel == WhiteKernel
1415
@test metric(WhiteKernel()) == KernelFunctions.Delta()
16+
@test repr(k) == "White Kernel"
1517
end
1618
@testset "ConstantKernel" begin
1719
c = 2.0
@@ -21,5 +23,6 @@
2123
@test kappa(k,0.5) == c
2224
@test metric(ConstantKernel()) == KernelFunctions.Delta()
2325
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
26+
@test repr(k) == "Constant Kernel (c = $(c))"
2427
end
2528
end

test/basekernels/cosine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
@test kappa(k, 1.5) 0.0 atol=1e-5
1212
@test kappa(k,x) cospi(x) atol=1e-5
1313
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
14+
@test repr(k) == "Cosine Kernel"
1415
end

0 commit comments

Comments
 (0)