Skip to content

Commit fe5487e

Browse files
committed
Corrected transform behavior as a constructor
1 parent a267406 commit fe5487e

File tree

8 files changed

+36
-17
lines changed

8 files changed

+36
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*.json
2+
*.cov
23
Manifest.toml
34
coverage/

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ The aim is to make the API as model-agnostic as possible while still being user-
1313
```julia
1414
X = reshape(collect(range(-3.0,3.0,length=100)),:,1)
1515
# Set simple scaling of the data
16-
k₁ = sqexponentialkernel(1.0)
16+
k₁ = SqExponentialKernel()
1717
K₁ = kernelmatrix(k₁,X,obsdim=1)
1818

1919
# Set a function transformation on the data
2020
k₂ = TransformedKernel(Matern32Kernel(),FunctionTransform(x->sin.(x)))
2121
K₂ = kernelmatrix(k₂,X,obsdim=1)
2222

2323
# Set a matrix premultiplication on the data
24-
k₃ = polynomialkernel(LowRankTransform(randn(4,1)),2.0,0.0)
24+
k₃ = transform(PolynomialKernel(c=2.0,d=2.0),LowRankTransform(randn(4,1)))
2525
K₃ = kernelmatrix(k₃,X,obsdim=1)
2626

2727
# Add and sum kernels
28-
k₄ = 0.5*SqExponentialKernel()*linearkernel(0.5) + 0.4*k₂
28+
k₄ = 0.5*SqExponentialKernel()*LinearKernel(c=0.5) + 0.4*k₂
2929
K₄ = kernelmatrix(k₄,X,obsdim=1)
3030

3131
plot(heatmap.([K₁,K₂,K₃,K₄],yflip=true,colorbar=false)...,layout=(2,2),title=["K₁" "K₂" "K₃" "K₄"])

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
4+
export transform
45
export params, duplicate, set! # Helpers
56

67
export Kernel

src/generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626

2727
for k in nameof.(subtypes(BaseKernel))
2828
@eval begin
29-
@deprecate($k::Real;args...),TransformedKernel($k(args...),ScaleTransform(ρ)))
30-
@deprecate($k::AbstractVector{<:Real};args...),TransformedKernel($k(args...),ARDTransform(ρ)))
29+
@deprecate($k::Real;args...),transform($k(args...),ρ))
30+
@deprecate($k::AbstractVector{<:Real};args...),transform($k(args...),ρ))
3131
end
3232
end

src/kernels/transformedkernel.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,24 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
33
transform::Tr
44
end
55

6+
"""
7+
```julia
8+
transform(k::BaseKernel, t::Transform) (1)
9+
transform(k::BaseKernel, ρ::Real) (2)
10+
transform(k::BaseKernel, ρ::AbstractVector) (3)
11+
```
12+
(1) Create a TransformedKernel with transform `t` and kernel `k`
13+
(2) Same as (1) with a `ScaleTransform` with scale `ρ`
14+
(3) Same as (1) with an `ARDTransform` with scales `ρ`
15+
"""
16+
transform
17+
18+
transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t)
19+
20+
transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))
21+
22+
transform(k::BaseKernel::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))
23+
624
kernel(κ) = κ.kernel
725

826
kappa::TransformedKernel, x) = kappa.kernel, x)

src/transform/transform.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
export Transform, IdentityTransform, ScaleTransform, ARDTransform, LowRankTransform, FunctionTransform, ChainTransform
22

3-
"""
4-
```julia
5-
transform(t::Transform, X::AbstractMatrix)
6-
```
7-
Apply the transfomration `t` or `k.transform` on the input `X`
8-
"""
9-
transform
10-
113
include("scaletransform.jl")
124
include("ardtransform.jl")
135
include("lowranktransform.jl")

test/test_kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
1212
Kdiag = [zeros(dims[1]),zeros(dims[2])]
1313
s = rand()
1414
k = SqExponentialKernel()
15-
kt = sqexponentialkernel(s)
15+
kt = transform(SqExponentialKernel(),s)
1616
@testset "Kernel Matrix Operations" begin
1717
@testset "Inplace Kernel Matrix" begin
1818
for obsdim in [1,2]

test/test_kernels.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,17 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
120120
end
121121
@testset "Transformed/Scaled Kernel" begin
122122
s = rand()
123+
v = rand(3)
123124
k = SqExponentialKernel()
124-
kt = KernelFunctions.TransformedKernel(k,ScaleTransform(s))
125-
ks = KernelFunctions.ScaledKernel(k,s)
126-
@test KernelFunctions.kappa(kt,v1,v2) == KernelFunctions.kappa(KernelFunctions.transform(k,ScaleTransform(s)),v1,v2)
125+
kt = TransformedKernel(k,ScaleTransform(s))
126+
ktard = TransformedKernel(k,ARDTransform(v))
127+
ks = ScaledKernel(k,s)
128+
@test kappa(kt,v1,v2) == kappa(transform(k,ScaleTransform(s)),v1,v2)
129+
@test kappa(kt,v1,v2) == kappa(transform(k,s),v1,v2)
130+
@test kappa(kt,v1,v2) == kappa(k,s*v1,s*v2)
131+
@test kappa(ktard,v1,v2) == kappa(transform(k,ARDTransform(v)),v1,v2)
132+
@test kappa(ktard,v1,v2) == kappa(transform(k,v),v1,v2)
133+
@test kappa(ktard,v1,v2) == kappa(k,v.*v1,v.*v2)
127134
@test KernelFunctions.metric(kt) == KernelFunctions.metric(k)
128135
@test kappa(ks,x) == s*kappa(k,x)
129136
@test kappa(ks,x) == kappa(s*k,x)

0 commit comments

Comments
 (0)