Commit a5bcc63

Merge pull request #32 from theogf/remove-transform
Removing transform field and creating TransformedKernel (and ScaledKernel)
2 parents f63be92 + ee6fc0a commit a5bcc63

33 files changed: +437 -407 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*.json
2+
*.cov
23
Manifest.toml
34
coverage/

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@ version = "0.2.4"
55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -13,7 +14,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1314
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1415

1516
[compat]
16-
Compat = "2.2, 3.2"
17+
Compat = "2.2, 3"
1718
Distances = "0.8"
1819
Requires = "1.0.1"
1920
SpecialFunctions = "0.8, 0.9, 0.10"

README.md

Lines changed: 4 additions & 4 deletions
@@ -13,19 +13,19 @@ The aim is to make the API as model-agnostic as possible while still being user-
1313
```julia
1414
X = reshape(collect(range(-3.0,3.0,length=100)),:,1)
1515
# Set simple scaling of the data
16-
k₁ = SqExponentialKernel(1.0)
16+
k₁ = SqExponentialKernel()
1717
K₁ = kernelmatrix(k₁,X,obsdim=1)
1818

1919
# Set a function transformation on the data
20-
k₂ = MaternKernel(FunctionTransform(x->sin.(x)))
20+
k₂ = TransformedKernel(Matern32Kernel(),FunctionTransform(x->sin.(x)))
2121
K₂ = kernelmatrix(k₂,X,obsdim=1)
2222

2323
# Set a matrix premultiplication on the data
24-
k₃ = PolynomialKernel(LowRankTransform(randn(4,1)),2.0,0.0)
24+
k₃ = transform(PolynomialKernel(c=2.0,d=2.0),LowRankTransform(randn(4,1)))
2525
K₃ = kernelmatrix(k₃,X,obsdim=1)
2626

2727
# Add and sum kernels
28-
k₄ = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*k₂
28+
k₄ = 0.5*SqExponentialKernel()*LinearKernel(c=0.5) + 0.4*k₂
2929
K₄ = kernelmatrix(k₄,X,obsdim=1)
3030

3131
plot(heatmap.([K₁,K₂,K₃,K₄],yflip=true,colorbar=false)...,layout=(2,2),title=["K₁" "K₂" "K₃" "K₄"])

docs/src/api.md

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,7 @@ CurrentModule = KernelFunctions
1414
KernelFunctions
1515
```
1616

17-
## Kernel Functions
17+
## Base Kernels
1818

1919
```@docs
2020
SqExponentialKernel
@@ -33,9 +33,11 @@ ConstantKernel
3333
WhiteKernel
3434
```
3535

36-
## Kernel Combinations
36+
## Composite Kernels
3737

3838
```@docs
39+
TransformedKernel
40+
ScaledKernel
3941
KernelSum
4042
KernelProduct
4143
```

docs/src/kernels.md

Lines changed: 15 additions & 1 deletion
@@ -2,6 +2,10 @@
22
CurrentModule = KernelFunctions
33
```
44

5+
# Base Kernels
6+
7+
These are the basic kernels without any transformation of the data. They are the building blocks of KernelFunctions.
8+
59
## Exponential Kernels
610

711
### Exponential Kernel
@@ -13,7 +17,7 @@ The [Exponential Kernel](@ref ExponentialKernel) is defined as
1317

1418
### Square Exponential Kernel
1519

16-
The [Square Exponential Kernel](@ref KernelFunctions.SqExponentialKernel) is defined as
20+
The [Square Exponential Kernel](@ref KernelFunctions.SqExponentialKernel) is defined as
1721
```math
1822
k(x,x') = \exp\left(-\|x-x'\|^2\right)
1923
```
@@ -91,3 +95,13 @@ The [Square Exponential Kernel](@ref KernelFunctions.SqExponentialKernel) is def
9195
```math
9296
k(x,x') = 0
9397
```
98+
99+
# Composite Kernels
100+
101+
## TransformedKernel
102+
103+
## ScaledKernel
104+
105+
## KernelSum
106+
107+
## KernelProduct

docs/src/metrics.md

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,15 @@
22

33
KernelFunctions.jl relies on [Distances.jl]() for computing the pairwise matrix.
44
To do so a distance measure is needed for each kernel. Two very common ones can already be used : `SqEuclidean` and `Euclidean`.
5-
However all kernels do not rely on distances metrics respecting all the definitions. That's why two additional metrics come with the package : `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`). If you want to create a new distance just implement the following :
5+
However all kernels do not rely on distances metrics respecting all the definitions. That's why two additional metrics come with the package : `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`).
6+
Note that every base kernel must define its metric, for example:
7+
```julia
8+
metric(::CustomKernel) = SqEuclidean()
9+
```
10+
11+
## Adding a new metric
12+
13+
If you want to create a new distance just implement the following :
614

715
```julia
816
struct Delta <: Distances.PreMetric

docs/src/transform.md

Lines changed: 3 additions & 2 deletions
@@ -1,10 +1,10 @@
11
# Transform
22

3-
`Transform` is the object that takes care of transforming the input data before distances are being computed. It can be as standard as `IdentityTransform` returning the same input, can be a scalar with `ScaleTransform` multiplying the vectors by a scalar or a vector.
3+
`Transform` is the object that takes care of transforming the input data before distances are being computed. It can be as standard as `IdentityTransform` returning the same input, or multiplying the data by a scalar with `ScaleTransform` or by a vector with `ARDTransform`.
44
There is a more general `Transform`: `FunctionTransform` that uses a function and apply it on each vector via `mapslices`.
55
You can also create a pipeline of `Transform` via `TransformChain`. For example `LowRankTransform(rand(10,5))∘ScaleTransform(2.0)`.
66

7-
One apply a transformation on a matrix or a vector via `transform(t::Transform,v::AbstractVecOrMat)`
7+
One apply a transformation on a matrix or a vector via `KernelFunctions.apply(t::Transform,v::AbstractVecOrMat)`
88

99
## Transforms :
1010
```@meta
@@ -14,6 +14,7 @@ CurrentModule = KernelFunctions
1414
```@docs
1515
IdentityTransform
1616
ScaleTransform
17+
ARDTransform
1718
LowRankTransform
1819
FunctionTransform
1920
ChainTransform

src/KernelFunctions.jl

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,7 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa # Main matrix functions
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4+
export transform
45
export params, duplicate, set! # Helpers
56

67
export Kernel
@@ -11,6 +12,7 @@ export MaternKernel, Matern32Kernel, Matern52Kernel
1112
export LinearKernel, PolynomialKernel
1213
export RationalQuadraticKernel, GammaRationalQuadraticKernel
1314
export KernelSum, KernelProduct
15+
export TransformedKernel, ScaledKernel
1416

1517
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1618

@@ -22,6 +24,7 @@ using Distances, LinearAlgebra
2224
using SpecialFunctions: logabsgamma, besselk
2325
using ZygoteRules: @adjoint
2426
using StatsFuns: logtwo
27+
using InteractiveUtils: subtypes
2528
using StatsBase
2629

2730
const defaultobs = 2
@@ -30,7 +33,8 @@ const defaultobs = 2
3033
Abstract type defining a slice-wise transformation on an input matrix
3134
"""
3235
abstract type Transform end
33-
abstract type Kernel{Tr<:Transform} end
36+
abstract type Kernel end
37+
abstract type BaseKernel <: Kernel end
3438

3539
include("utils.jl")
3640
include("distances/dotproduct.jl")
@@ -40,6 +44,8 @@ include("transform/transform.jl")
4044
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated"]
4145
include(joinpath("kernels",k*".jl"))
4246
end
47+
include("kernels/transformedkernel.jl")
48+
include("kernels/scaledkernel.jl")
4349
include("matrix/kernelmatrix.jl")
4450
include("kernels/kernelsum.jl")
4551
include("kernels/kernelproduct.jl")
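
The type hierarchy introduced in this file can be sketched in isolation. The block below is a minimal illustration under stated assumptions, not the package's code: `SketchKernel`, `SketchBaseKernel`, `SketchSqExp`, `SketchExp`, `SketchTransformed` and `sketch_kappa` are hypothetical stand-ins. It shows a transform-free base type plus a wrapper taking the place of the old `Kernel{Tr}` parameterisation, and `subtypes` being used to enumerate the base kernels.

```julia
using InteractiveUtils: subtypes

# Hypothetical stand-ins for Kernel and BaseKernel
abstract type SketchKernel end
abstract type SketchBaseKernel <: SketchKernel end

# Base kernels carry no transform field
struct SketchSqExp <: SketchBaseKernel end
struct SketchExp <: SketchBaseKernel end

# A wrapper pairs any kernel with a transformation (here: any callable)
struct SketchTransformed{K<:SketchKernel,F} <: SketchKernel
    kernel::K
    transform::F
end

sketch_kappa(::SketchSqExp, d²::Real) = exp(-d²)
sketch_kappa(::SketchExp, d::Real) = exp(-d)

# Enumerate the concrete base kernels instead of listing them by hand
basekernel_names = sort(nameof.(subtypes(SketchBaseKernel)))

# Generate call syntax for every base kernel found by `subtypes`
for K in subtypes(SketchBaseKernel)
    @eval (κ::$K)(d::Real) = sketch_kappa(κ, d)
end

value_at_zero = SketchSqExp()(0.0)
```

One consequence of this design is that `subtypes` only sees types that already exist when it is called, so a loop of this kind has to run after every base kernel has been defined.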

src/generic.jl

Lines changed: 15 additions & 16 deletions
@@ -1,33 +1,32 @@
1-
@inline metric(κ::Kernel) = κ.metric
2-
31
## Allows to iterate over kernels
42
Base.length(::Kernel) = 1
53
Base.iterate(k::Kernel) = (k,nothing)
64
Base.iterate(k::Kernel, ::Any) = nothing
75

86
# default fallback for evaluating a kernel with two arguments (such as vectors etc)
9-
kappa(κ::Kernel, x, y) = kappa(κ, evaluate(metric(κ), transform(κ, x), transform(κ, y)))
7+
kappa(κ::Kernel, x, y) = kappa(κ, evaluate(metric(κ), x, y))
8+
kappa(κ::TransformedKernel, x, y) = kappa(kernel(κ), apply(κ.transform,x), apply(κ.transform,y))
9+
kappa(κ::TransformedKernel{<:BaseKernel,<:ScaleTransform}, x, y) = kappa(κ, _scale(κ.transform, metric(κ), x, y))
10+
_scale(t::ScaleTransform, metric::Euclidean, x, y) = first(t.s) * evaluate(metric, x, y)
11+
_scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) = first(t.s)^2 * evaluate(metric, x, y)
12+
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t, y))
13+
14+
printshifted(io::IO,κ::Kernel,shift::Int) = print(io,"$κ")
15+
Base.show(io::IO,κ::Kernel) = print(io,nameof(typeof(κ)))
1016

1117
### Syntactic sugar for creating matrices and using kernel functions
12-
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
18+
for k in subtypes(BaseKernel)
1319
@eval begin
1420
@inline (κ::$k)(d::Real) = kappa(κ,d) #TODO Add test
1521
@inline (κ::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
16-
@inline (κ::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
17-
@inline (κ::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)
22+
@inline (κ::$k)(X::AbstractMatrix{T}, Y::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, Y, obsdim=obsdim)
23+
@inline (κ::$k)(X::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, obsdim=obsdim)
1824
end
1925
end
2026

21-
### Transform generics
22-
@inline transform(κ::Kernel) = κ.transform
23-
@inline transform(κ::Kernel, x) = transform(transform(κ), x)
24-
@inline transform(κ::Kernel, x, obsdim::Int) = transform(transform(κ), x, obsdim)
25-
26-
## Constructors for kernels without parameters
27-
for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52Kernel,:ExponentiatedKernel]
27+
for k in nameof.(subtypes(BaseKernel))
2828
@eval begin
29-
$kernel() = $kernel(IdentityTransform())
30-
$kernel(ρ::Real) = $kernel(ScaleTransform(ρ))
31-
$kernel(ρ::AbstractVector{<:Real}) = $kernel(ARDTransform(ρ))
29+
@deprecate($k(ρ::Real;args...),transform($k(args...),ρ))
30+
@deprecate($k(ρ::AbstractVector{<:Real};args...),transform($k(args...),ρ))
3231
end
3332
end
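
The `_scale` methods above avoid transforming the inputs when the transform is a plain scalar scaling. A minimal numerical check of the identities they rely on follows; it is a sketch using only `LinearAlgebra`, with `euclidean` and `sqeuclidean` defined locally as stand-ins for the Distances.jl metrics, and it assumes a non-negative scale `s`.

```julia
using LinearAlgebra: norm, dot

s = 2.5                      # assumed non-negative scale factor
x = [1.0, -2.0, 0.5]
y = [0.0, 1.0, 3.0]

# Local stand-ins for the Euclidean and squared Euclidean metrics
euclidean(a, b) = norm(a - b)
sqeuclidean(a, b) = sum(abs2, a - b)

# Left-hand sides: scale the inputs, then measure
# Right-hand sides: measure once, then rescale the result
euclid_pair = (euclidean(s * x, s * y), s * euclidean(x, y))
sqeuclid_pair = (sqeuclidean(s * x, s * y), s^2 * sqeuclidean(x, y))
dot_pair = (dot(s * x, s * y), s^2 * dot(x, y))
```

Each pair holds two values that agree up to floating-point rounding, which is why the scalar case can skip the call to `apply`. For a negative `s` the Euclidean identity would need `abs(s)` instead of `s`.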

src/kernels/constant.jl

Lines changed: 19 additions & 25 deletions
@@ -1,55 +1,49 @@
11
"""
2-
ZeroKernel([tr=IdentityTransform()])
2+
ZeroKernel()
33
4-
Create a kernel always returning zero
4+
Create a kernel that always returns zero
5+
```
6+
κ(x,y) = 0.0
7+
```
8+
The output type depends on `x` and `y`
59
"""
6-
struct ZeroKernel{Tr} <: Kernel{Tr}
7-
transform::Tr
8-
end
9-
10-
ZeroKernel() = ZeroKernel(IdentityTransform())
10+
struct ZeroKernel <: BaseKernel end
1111

12-
@inline kappa(κ::ZeroKernel, d::T) where {T<:Real} = zero(T)
12+
kappa(κ::ZeroKernel, d::T) where {T<:Real} = zero(T)
1313

1414
metric(::ZeroKernel) = Delta()
1515

1616
"""
17-
`WhiteKernel([tr=IdentityTransform()])`
17+
`WhiteKernel()`
1818
1919
```
2020
κ(x,y) = δ(x,y)
2121
```
2222
Kernel function working as an equivalent to add white noise.
2323
"""
24-
struct WhiteKernel{Tr} <: Kernel{Tr}
25-
transform::Tr
26-
end
27-
28-
WhiteKernel() = WhiteKernel(IdentityTransform())
24+
struct WhiteKernel <: BaseKernel end
2925

30-
@inline kappa(κ::WhiteKernel,δₓₓ::Real) = δₓₓ
26+
kappa(κ::WhiteKernel,δₓₓ::Real) = δₓₓ
3127

3228
metric(::WhiteKernel) = Delta()
3329

3430
"""
35-
`ConstantKernel([tr=IdentityTransform(),[c=1.0]])`
31+
`ConstantKernel(c=1.0)`
3632
```
3733
κ(x,y) = c
3834
```
3935
Kernel function always returning a constant value `c`
4036
"""
41-
struct ConstantKernel{Tr, Tc<:Real} <: Kernel{Tr}
42-
transform::Tr
37+
struct ConstantKernel{Tc<:Real} <: BaseKernel
4338
c::Tc
39+
function ConstantKernel(;c::T=1.0) where {T<:Real}
40+
new{T}(c)
41+
end
4442
end
4543

46-
params(k::ConstantKernel) = (params(k.transform),k.c)
47-
opt_params(k::ConstantKernel) = (opt_params(k.transform),k.c)
48-
49-
ConstantKernel(c::Real=1.0) = ConstantKernel(IdentityTransform(),c)
50-
51-
ConstantKernel(t::Tr,c::Tc=1.0) where {Tr<:Transform,Tc<:Real} = ConstantKernel{Tr,Tc}(t,c)
44+
params(k::ConstantKernel) = (k.c,)
45+
opt_params(k::ConstantKernel) = (k.c,)
5246

53-
@inline kappa(κ::ConstantKernel,x::Real) = κ.c
47+
kappa(κ::ConstantKernel,x::Real) = κ.c*one(x)
5448

5549
metric(::ConstantKernel) = Delta()
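
The keyword-only inner constructor used for `ConstantKernel` can be tried out on its own. The following is a minimal sketch under stated assumptions: `SketchConstant` and `constant_kappa` are hypothetical names, not part of the package. It shows the type parameter being inferred from the keyword value, and `κ.c*one(x)` promoting the result to the type of the distance argument.

```julia
struct SketchConstant{Tc<:Real}
    c::Tc
    # Keyword-only inner constructor: T is taken from the value passed as `c`
    function SketchConstant(; c::T=1.0) where {T<:Real}
        return new{T}(c)
    end
end

# Multiplying by one(x) promotes the constant to the type of x
constant_kappa(κ::SketchConstant, x::Real) = κ.c * one(x)

k_default = SketchConstant()     # c defaults to 1.0
k_int = SketchConstant(c=2)      # integer constant

promoted = constant_kappa(k_int, 0.5)   # Int constant, Float64 distance
```

Because an inner constructor is defined, Julia no longer generates the default positional constructor, so in this sketch the constant can only be set through the `c` keyword.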

src/kernels/exponential.jl

Lines changed: 20 additions & 33 deletions
@@ -1,22 +1,22 @@
11
"""
2-
`SqExponentialKernel([ρ=1.0])`
2+
`SqExponentialKernel()`
33
44
The squared exponential kernel is an isotropic Mercer kernel given by the formula:
55
```
6-
κ(x,y) = exp(-ρ²‖x-y‖²)
6+
κ(x,y) = exp(-‖x-y‖²)
77
```
88
See also [`ExponentialKernel`](@ref) for a
99
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
1010
"""
11-
struct SqExponentialKernel{Tr} <: Kernel{Tr}
12-
transform::Tr
13-
end
11+
struct SqExponentialKernel <: BaseKernel end
1412

15-
@inline kappa(κ::SqExponentialKernel, d²::Real) = exp(-d²)
16-
@inline iskroncompatible(::SqExponentialKernel) = true
13+
kappa(κ::SqExponentialKernel, d²::Real) = exp(-d²)
14+
iskroncompatible(::SqExponentialKernel) = true
1715

1816
metric(::SqExponentialKernel) = SqEuclidean()
1917

18+
Base.show(io::IO,::SqExponentialKernel) = print(io,"Squared Exponential Kernel")
19+
2020
## Aliases ##
2121
const RBFKernel = SqExponentialKernel
2222
const GaussianKernel = SqExponentialKernel
@@ -28,14 +28,14 @@ The exponential kernel is an isotropic Mercer kernel given by the formula:
2828
κ(x,y) = exp(-ρ‖x-y‖)
2929
```
3030
"""
31-
struct ExponentialKernel{Tr} <: Kernel{Tr}
32-
transform::Tr
33-
end
31+
struct ExponentialKernel <: BaseKernel end
3432

35-
@inline kappa(κ::ExponentialKernel, d::Real) = exp(-d)
36-
@inline iskroncompatible(::ExponentialKernel) = true
33+
kappa(κ::ExponentialKernel, d::Real) = exp(-d)
34+
iskroncompatible(::ExponentialKernel) = true
3735
metric(::ExponentialKernel) = Euclidean()
3836

37+
Base.show(io::IO,::ExponentialKernel) = print(io,"Exponential Kernel")
38+
3939
## Alias ##
4040
const LaplacianKernel = ExponentialKernel
4141

@@ -46,30 +46,17 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
4646
κ(x,y) = exp(-ρ^(2γ)‖x-y‖^(2γ))
4747
```
4848
"""
49-
struct GammaExponentialKernel{Tr, Tγ<:Real} <: Kernel{Tr}
50-
transform::Tr
49+
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
5150
γ::Tγ
52-
function GammaExponentialKernel{Tr,Tγ}(t::Tr, γ::Tγ) where {Tr<:Transform,Tγ<:Real}
53-
@check_args(GammaExponentialKernel, γ, γ >= zero(Tγ), "γ > 0")
54-
return new{Tr, Tγ}(t, γ)
51+
function GammaExponentialKernel(;γ::T=2.0) where {T<:Real}
52+
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")
53+
return new{T}(γ)
5554
end
5655
end
5756

58-
params(k::GammaExponentialKernel) = (params(transform),γ)
59-
opt_params(k::GammaExponentialKernel) = (opt_params(transform),γ)
60-
61-
function GammaExponentialKernel(ρ::Real=1.0, γ::Real=2.0)
62-
GammaExponentialKernel(ScaleTransform(ρ), γ)
63-
end
64-
65-
function GammaExponentialKernel(ρ::AbstractVector{<:Real}, γ::Real=2.0)
66-
GammaExponentialKernel(ARDTransform(ρ), γ)
67-
end
68-
69-
function GammaExponentialKernel(t::Tr, γ::Tγ=2.0) where {Tr<:Transform, Tγ<:Real}
70-
GammaExponentialKernel{Tr, Tγ}(t, γ)
71-
end
57+
params(k::GammaExponentialKernel) = (γ,)
58+
opt_params(k::GammaExponentialKernel) = (γ,)
7259

73-
@inline kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^κ.γ)
74-
@inline iskroncompatible(::GammaExponentialKernel) = true
60+
kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^κ.γ)
61+
iskroncompatible(::GammaExponentialKernel) = true
7562
metric(::GammaExponentialKernel) = SqEuclidean()
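
The transform-free γ-exponential kernel can be sketched as a standalone type. This is an illustration under stated assumptions rather than the package's implementation: `SketchGammaExp`, `gammaexp_kappa` and `gammaexp_params` are hypothetical names, and a plain `ArgumentError` stands in for the `@check_args` macro. It shows the keyword constructor with its argument check, parameter access through the kernel's field, and that `γ = 1` applied to a squared distance gives the squared exponential form.

```julia
struct SketchGammaExp{Tγ<:Real}
    γ::Tγ
    function SketchGammaExp(; γ::T=2.0) where {T<:Real}
        # Stand-in for @check_args: reject negative exponents
        γ >= zero(T) || throw(ArgumentError("γ must be non-negative"))
        return new{T}(γ)
    end
end

# The kernel acts on a squared distance d²; `^` binds tighter than unary minus
gammaexp_kappa(κ::SketchGammaExp, d²::Real) = exp(-d²^κ.γ)

# Parameters are read from the kernel instance
gammaexp_params(κ::SketchGammaExp) = (κ.γ,)

k = SketchGammaExp(γ=1.0)
same_as_sqexp = gammaexp_kappa(k, 4.0) ≈ exp(-4.0)
```

In this sketch the parameter tuple is built from `κ.γ`, since a bare `γ` is not in scope outside the constructor.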
