Skip to content

Commit 1c70c5d

Browse files
committed
Pointed out where tests are needed
1 parent 22f5464 commit 1c70c5d

File tree

10 files changed

+45
-37
lines changed

10 files changed

+45
-37
lines changed

src/generic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
@inline metric::Kernel) = κ.metric
22

33
## Allows to iterate over kernels
4-
Base.length(::Kernel) = 1
4+
Base.length(::Kernel) = 1 #TODO Add test
55

6-
Base.iterate(k::Kernel) = (k,nothing)
7-
Base.iterate(k::Kernel, ::Any) = nothing
6+
Base.iterate(k::Kernel) = (k,nothing) #TODO Add test
7+
Base.iterate(k::Kernel, ::Any) = nothing #TODO Add test
88

99
### Syntactic sugar for creating matrices and using kernel functions
1010
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
1111
@eval begin
12-
@inline::$k)(d::Real) = kappa(κ,d)
12+
@inline::$k)(d::Real) = kappa(κ,d) #TODO Add test
1313
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate.metric,transform(κ,x),transform(κ,y)))
1414
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
1515
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)

src/kernels/constant.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ZeroKernel()
2+
ZeroKernel([tr=IdentityTransform()])
33
44
Create a kernel that always return a zero kernel matrix
55
@@ -19,7 +19,7 @@ end
1919
@inline kappa::ZeroKernel,d::T) where {T<:Real} = zero(T)
2020

2121
"""
22-
WhiteKernel()
22+
WhiteKernel([tr=IdentityTransform()])
2323
2424
```
2525
κ(x,y) = δ(x,y)
@@ -41,7 +41,7 @@ end
4141
@inline kappa::WhiteKernel,δₓₓ::Real) = δₓₓ
4242

4343
"""
44-
ConstantKernel([c=1.0])
44+
ConstantKernel([tr=IdentityTransform(),[c=1.0]])
4545
4646
```
4747
κ(x,y) = c

src/kernels/exponential.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,6 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
5555
```
5656
κ(x,y) = exp(-‖x-y‖^2γ)
5757
```
58-
59-
# Examples
60-
61-
```jldoctest; setup = :(using KernelFunctions)
62-
julia> GammaExponentialKernel()
63-
GammaExponentialKernel{Float64,Float64,Float64}(1.0,2.0)
64-
65-
julia> GammaExponentialKernel(2.0f0,3.0)
66-
GammaExponentialKernel{Float32,Float32,Float64}(2.0,3.0)
67-
68-
julia> GammaExponentialKernel([2.0,3.0],2f0)
69-
GammaExponentialKernel{Float64,Array{Float64},Float32}([2.0,3.0],2.0)
70-
```
7158
"""
7259
struct GammaExponentialKernel{T,Tr,Tᵧ<:Real} <: Kernel{T,Tr}
7360
transform::Tr

src/kernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ExponentiatedKernel([α=1])
2+
ExponentiatedKernel([ρ=1])
33
44
The exponentiated kernel is a Mercer kernel given by:
55

src/kernels/kernelproduct.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
KernelProduct(kernels::Array{Kernel})
3+
Create a multiplication of kernels.
4+
One can also use the operator `*`
5+
```
6+
kernelmatrix(SqExponentialKernel()*LinearKernel(),X) == kernelmatrix(SqExponentialKernel(),X).*kernelmatrix(LinearKernel(),X)
7+
```
8+
"""
19
struct KernelProduct{T,Tr} <: Kernel{T,Tr}
210
kernels::Vector{Kernel}
311
end
@@ -7,14 +15,15 @@ function KernelProduct(kernels::AbstractVector{<:Kernel})
715
end
816

917
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
18+
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
1019
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
1120
Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
1221

1322
Base.length(k::KernelProduct) = length(k.kernels)
14-
metric(k::KernelProduct) = getmetric.(k.kernels)
15-
transform(k::KernelProduct) = transform.(k.kernels)
16-
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x])
17-
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim)
23+
metric(k::KernelProduct) = getmetric.(k.kernels) #TODO Add test
24+
transform(k::KernelProduct) = transform.(k.kernels) #TODO Add test
25+
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x]) #TODO Add test
26+
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim) #TODO Add test
1827

1928
hadamard(x,y) = x.*y
2029

@@ -36,6 +45,6 @@ end
3645
function kerneldiagmatrix(
3746
κ::KernelProduct,
3847
X::AbstractMatrix;
39-
obsdim::Int=defaultobs)
48+
obsdim::Int=defaultobs) #TODO Add test
4049
reduce(hadamard,kerneldiagmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
4150
end

src/kernels/kernelsum.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
KernelSum(kernels::Array{Kernel};weights::Array{Real}=ones(length(kernels)))
3+
Create a positive weighted sum of kernels.
4+
One can also use the operator `+`
5+
```
6+
kernelmatrix(SqExponentialKernel()+LinearKernel(),X) == kernelmatrix(SqExponentialKernel(),X).+kernelmatrix(LinearKernel(),X)
7+
```
8+
"""
19
struct KernelSum{T,Tr} <: Kernel{T,Tr}
210
kernels::Vector{Kernel}
311
weights::Vector{Real}
@@ -14,8 +22,11 @@ function KernelSum(kernels::AbstractVector{<:Kernel}; weights::AbstractVector{<:
1422
end
1523

1624
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
25+
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))
1726
Base.:+(k::Kernel,ks::KernelSum) = KernelSum(vcat(k,ks.kernels),weights=vcat(1.0,ks.weights))
1827
Base.:+(ks::KernelSum,k::Kernel) = KernelSum(vcat(ks.kernels,k),weights=vcat(ks.weights,1.0))
28+
Base.:*(w::Real,k::Kernel) = KernelSum([k],[w]) #TODO add tests
29+
1930

2031
Base.length(k::KernelSum) = length(k.kernels)
2132
metric(k::KernelSum) = metric.(k.kernels)

src/kernels/matern.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-lgamma.ν) + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
3939

4040
"""
41-
Matern32Kernel(ρ=1.0)
41+
Matern32Kernel([ρ=1.0])
4242
4343
The matern 3/2 kernel is an isotropic Mercer kernel given by the formula:
4444
@@ -59,7 +59,7 @@ end
5959
@inline kappa::Matern32Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
6060

6161
"""
62-
Matern52Kernel(ρ=1.0)
62+
Matern52Kernel([ρ=1.0])
6363
6464
The matern 5/2 kernel is an isotropic Mercer kernel given by the formula:
6565

src/transform/lowranktransform.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ struct LowRankTransform{T<:AbstractMatrix{<:Real}} <: Transform
1212
end
1313

1414
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
15-
Base.size(tr::LowRankTransform) = size(tr.proj)
15+
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test
1616

1717
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs)
1818
@boundscheck size(t,2) != size(X,feature_dim(obsdim)) ?
1919
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))")) : nothing
2020
@inbounds _transform(t,X,obsdim)
2121
end
22-
function transform(t::LowRankTransform,x::AbstractVector{<:Real})
23-
@assert size(t,2) == length(x) "Vector has wrong dimensions"
22+
23+
function transform(t::LowRankTransform,x::AbstractVector{<:Real},obsdim::Int=defaultobs) #TODO Add test
24+
@assert size(t,2) == length(x) "Vector has wrong dimensions $(length(x)) compared to projection matrix"
2425
t.proj*X
2526
end
2627

src/transform/scaletransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
1818
ScaleTransform{T}(s)
1919
end
2020

21-
function ScaleTransform(s::T,dims::Integer) where {T<:Real}
21+
function ScaleTransform(s::T,dims::Integer) where {T<:Real} # TODO Add test
2222
@check_args(ScaleTransform, s, s > zero(T), "s > 0")
2323
ScaleTransform{Vector{T}}(fill(s,dims))
2424
end
@@ -28,12 +28,12 @@ function ScaleTransform(s::A) where {A<:AbstractVector{<:Real}}
2828
ScaleTransform{A}(s)
2929
end
3030

31-
dim(str::ScaleTransform{<:Real}) = 1
31+
dim(str::ScaleTransform{<:Real}) = 1 #TODO Add test
3232
dim(str::ScaleTransform{<:AbstractVector{<:Real}}) = length(str.s)
3333

3434
function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int)
3535
@boundscheck if dim(t) != size(X,!Bool(obsdim-1)+1)
36-
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t))."))
36+
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t)).")) #TODO Add test
3737
end
3838
_transform(t,X,obsdim)
3939
end

src/transform/transform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct ChainTransform <: Transform
2222
transforms::Vector{Transform}
2323
end
2424

25-
Base.length(t::ChainTransform) = length(t.transforms)
25+
Base.length(t::ChainTransform) = length(t.transforms) #TODO Add test
2626

2727
function ChainTransform(v::AbstractVector{<:Transform})
2828
ChainTransform(v)
@@ -37,7 +37,7 @@ function transform(t::ChainTransform,X::T,obsdim::Int=defaultobs) where {T}
3737
end
3838

3939
Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
40-
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t))
40+
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
4141
Base.:(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
4242
"""
4343
IdentityTransform
@@ -46,7 +46,7 @@ Base.:∘(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms
4646
"""
4747
struct IdentityTransform <: Transform end
4848

49-
transform(t::IdentityTransform,x::AbstractArray,obsdim::Int=defaultobs) = x
49+
transform(t::IdentityTransform,x::AbstractArray,obsdim::Int=defaultobs) = x #TODO add test
5050

5151
### TODO Maybe defining adjoints could help but so far it's not working
5252

0 commit comments

Comments
 (0)