
Commit a0daa6c

Merge pull request #41 from theogf/params
params(), Flux/Zygote style
2 parents 716285c + 3e47144

22 files changed with 133 additions and 74 deletions.
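
A short usage sketch of the new interface (a hypothetical session; it assumes a Flux version whose parameter collection recurses through `trainable`, and the exact calls may differ between releases):

```julia
using KernelFunctions
using Flux  # loading Flux triggers the @require include of src/trainable.jl

# Scalar hyperparameters are now stored in 1-element vectors,
# so an optimiser can update them in place.
k = GammaExponentialKernel(γ=1.5)

Flux.trainable(k)    # -> (k.γ,), i.e. ([1.5],)
ps = Flux.params(k)  # collects k.γ, assuming params recurses via trainable
```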

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -25,11 +25,12 @@ julia = "1.0"
 
 [extras]
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]
+test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ include("zygote_adjoints.jl")
 function __init__()
     @require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
     @require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
+    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
 end
 
 end

src/kernels/constant.jl

Lines changed: 3 additions & 6 deletions
@@ -35,15 +35,12 @@ metric(::WhiteKernel) = Delta()
 Kernel function always returning a constant value `c`
 """
 struct ConstantKernel{Tc<:Real} <: BaseKernel
-    c::Tc
+    c::Vector{Tc}
     function ConstantKernel(;c::T=1.0) where {T<:Real}
-        new{T}(c)
+        new{T}([c])
     end
 end
 
-params(k::ConstantKernel) = (k.c,)
-opt_params(k::ConstantKernel) = (k.c,)
-
-kappa(κ::ConstantKernel,x::Real) = κ.c*one(x)
+kappa(κ::ConstantKernel,x::Real) = first(κ.c)*one(x)
 
 metric(::ConstantKernel) = Delta()
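
The same pattern recurs throughout this commit: each scalar hyperparameter is wrapped in a `Vector` so the immutable kernel struct holds a mutable container, and `kappa` reads the scalar back with `first`. A minimal sketch of the effect (hypothetical values):

```julia
k = ConstantKernel(c=1.0)

first(k.c)    # 1.0 — the scalar actually used by kappa
k.c .= [3.0]  # update the hyperparameter in place, e.g. after a gradient step,
              # without rebuilding the kernel
first(k.c)    # 3.0
```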

src/kernels/exponential.jl

Lines changed: 3 additions & 6 deletions
@@ -47,16 +47,13 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
 ```
 """
 struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
-    γ::Tγ
+    γ::Vector{Tγ}
     function GammaExponentialKernel(;γ::T=2.0) where {T<:Real}
         @check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")
-        return new{T}(γ)
+        return new{T}([γ])
     end
 end
 
-params(k::GammaExponentialKernel) = (γ,)
-opt_params(k::GammaExponentialKernel) = (γ,)
-
-kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^κ.γ)
+kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^first(κ.γ))
 iskroncompatible(::GammaExponentialKernel) = true
 metric(::GammaExponentialKernel) = SqEuclidean()
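
For reference, the formula the docstring refers to, which the `kappa` above evaluates on the squared Euclidean distance `d²`, is the standard γ-exponential kernel:

```latex
k(x, x') = \exp\!\bigl(-\lVert x - x'\rVert^{2\gamma}\bigr)
```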

src/kernels/kernelproduct.jl

Lines changed: 0 additions & 3 deletions
@@ -14,9 +14,6 @@ struct KernelProduct <: Kernel
     kernels::Vector{Kernel}
 end
 
-params(k::KernelProduct) = params.(k.kernels)
-opt_params(k::KernelProduct) = opt_params.(k.kernels)
-
 Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
 Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
 Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))

src/kernels/kernelsum.jl

Lines changed: 0 additions & 3 deletions
@@ -25,9 +25,6 @@ function KernelSum(
     return KernelSum(kernels, weights)
 end
 
-params(k::KernelSum) = (k.weights, params.(k.kernels))
-opt_params(k::KernelSum) = (k.weights, opt_params.(k.kernels))
-
 Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
 Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
 Base.:+(k1::KernelSum, k2::KernelSum) =
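
As the surrounding context lines show, `*` and `+` assemble `ScaledKernel`, `KernelProduct` and `KernelSum` objects, which the new `trainable` definitions in src/trainable.jl later recurse through. A brief sketch, assuming the constructors shown elsewhere in this diff:

```julia
k1 = 2.0 * SqExponentialKernel()    # ScaledKernel carrying the scale 2.0
k2 = 3.0 * MaternKernel(ν=2.5)      # ScaledKernel carrying the scale 3.0
ks = k1 + k2                        # KernelSum of the base kernels, weights = [2.0, 3.0]
kp = SqExponentialKernel() * LinearKernel()  # KernelProduct of the two kernels
```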

src/kernels/matern.jl

Lines changed: 6 additions & 6 deletions
@@ -7,17 +7,17 @@ The matern kernel is an isotropic Mercer kernel given by the formula:
 For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
 """
 struct MaternKernel{Tν<:Real} <: BaseKernel
-    ν::Tν
+    ν::Vector{Tν}
     function MaternKernel(;ν::T=1.5) where {T<:Real}
         @check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
-        return new{T}(ν)
+        return new{T}([ν])
     end
 end
 
-params(k::MaternKernel) = (k.ν,)
-opt_params(k::MaternKernel) = (k.ν,)
-
-@inline kappa(κ::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((one(d)-κ.ν)*logtwo-logabsgamma(κ.ν)[1] + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk(κ.ν,sqrt(2κ.ν)*d)))
+@inline function kappa(κ::MaternKernel, d::Real)
+    ν = first(κ.ν)
+    iszero(d) ? one(d) : exp((one(d)-ν)*logtwo-logabsgamma(ν)[1] + ν*log(sqrt(2ν)*d)+log(besselk(ν,sqrt(2ν)*d)))
+end
 
 metric(::MaternKernel) = Euclidean()
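
The rewritten `kappa` evaluates the Matérn covariance in log-space for numerical stability; the formula the docstring refers to is

```latex
k(d) = \frac{2^{1-\nu}}{\Gamma(\nu)} \bigl(\sqrt{2\nu}\, d\bigr)^{\nu} K_{\nu}\!\bigl(\sqrt{2\nu}\, d\bigr)
```

which is exactly exp((1-ν)·log 2 − log Γ(ν) + ν·log(√(2ν) d) + log K_ν(√(2ν) d)) as computed above.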

src/kernels/polynomial.jl

Lines changed: 8 additions & 14 deletions
@@ -7,16 +7,13 @@ The linear kernel is a Mercer kernel given by
 Where `c` is a real number
 """
 struct LinearKernel{Tc<:Real} <: BaseKernel
-    c::Tc
+    c::Vector{Tc}
     function LinearKernel(;c::T=0.0) where {T}
-        new{T}(c)
+        new{T}([c])
     end
 end
 
-params(k::LinearKernel) = (k.c,)
-opt_params(k::LinearKernel) = (k.c,)
-
-kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + κ.c
+kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
 
 metric(::LinearKernel) = DotProduct()
 
@@ -28,18 +25,15 @@ The polynomial kernel is a Mercer kernel given by
 ```
 Where `c` is a real number, and `d` is a shape parameter bigger than 1
 """
-struct PolynomialKernel{Td<:Real,Tc<:Real} <: BaseKernel
-    d::Td
-    c::Tc
+struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
+    d::Vector{Td}
+    c::Vector{Tc}
     function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}
         @check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
-        return new{Td, Tc}(d, c)
+        return new{Td, Tc}([d], [c])
     end
 end
 
-params(k::PolynomialKernel) = (k.d,k.c)
-opt_params(k::PolynomialKernel) = (k.d,k.c)
-
-kappa(κ::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^(κ.d)
+kappa(κ::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first(κ.c))^(first(κ.d))
 
 metric(::PolynomialKernel) = DotProduct()
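
The docstring formulas elided from the context above match the `kappa` definitions: writing `xᵀy` for the dot product,

```latex
k_{\mathrm{linear}}(x, x') = x^\top x' + c,
\qquad
k_{\mathrm{poly}}(x, x') = \bigl(x^\top x' + c\bigr)^{d}
```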

src/kernels/rationalquad.jl

Lines changed: 7 additions & 13 deletions
@@ -7,17 +7,14 @@ The rational-quadratic kernel is an isotropic Mercer kernel given by the formula
 where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
 """
 struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
-    α::Tα
+    α::Vector{Tα}
     function RationalQuadraticKernel(;α::T=2.0) where {T}
         @check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
-        return new{T}(α)
+        return new{T}([α])
     end
 end
 
-params(k::RationalQuadraticKernel) = (k.α,)
-opt_params(k::RationalQuadraticKernel) = (k.α,)
-
-kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/κ.α)^(-κ.α)
+kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/first(κ.α))^(-first(κ.α))
 
 metric(::RationalQuadraticKernel) = SqEuclidean()

@@ -30,18 +27,15 @@ The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the f
 where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
 """
 struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
-    α::Tα
-    γ::Tγ
+    α::Vector{Tα}
+    γ::Vector{Tγ}
     function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2.0) where {Tα<:Real, Tγ<:Real}
         @check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
         @check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
-        return new{Tα, Tγ}(α, γ)
+        return new{Tα, Tγ}([α], [γ])
     end
 end
 
-params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
-opt_params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
-
-kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²^κ.γ/κ.α)^(-κ.α)
+kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²^first(κ.γ)/first(κ.α))^(-first(κ.α))
 
 metric(::GammaRationalQuadraticKernel) = SqEuclidean()
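
Likewise, with `d²` the squared Euclidean distance, the two `kappa` definitions above correspond to

```latex
k_{\mathrm{RQ}}(d^2) = \Bigl(1 + \frac{d^2}{\alpha}\Bigr)^{-\alpha},
\qquad
k_{\gamma\text{-RQ}}(d^2) = \Bigl(1 + \frac{(d^2)^{\gamma}}{\alpha}\Bigr)^{-\alpha}
```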

src/kernels/scaledkernel.jl

Lines changed: 0 additions & 3 deletions
@@ -12,9 +12,6 @@ kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)
 
 metric(k::ScaledKernel) = metric(k.kernel)
 
-params(k::ScaledKernel) = (k.σ², params(k.kernel))
-opt_params(k::ScaledKernel) = (k.σ², opt_params(k.kernel))
-
 Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
 
 Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

src/kernels/transformedkernel.jl

Lines changed: 0 additions & 2 deletions
@@ -27,8 +27,6 @@ kappa(κ::TransformedKernel, x) = kappa(κ.kernel, x)
 
 metric(κ::TransformedKernel) = metric(κ.kernel)
 
-params(κ::TransformedKernel) = (params(κ.transform),params(κ.kernel))
-
 Base.show(io::IO,κ::TransformedKernel) = printshifted(io,κ,0)
 
 function printshifted(io::IO,κ::TransformedKernel,shift::Int)

src/trainable.jl

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import .Flux.trainable
+
+### Base Kernels
+
+trainable(k::ConstantKernel) = (k.c,)
+
+trainable(k::GammaExponentialKernel) = (k.γ,)
+
+trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)
+
+trainable(k::MaternKernel) = (k.ν,)
+
+trainable(k::LinearKernel) = (k.c,)
+
+trainable(k::PolynomialKernel) = (k.d, k.c)
+
+trainable(k::RationalQuadraticKernel) = (k.α,)
+
+#### Composite kernels
+
+trainable(κ::KernelProduct) = κ.kernels
+
+trainable(κ::KernelSum) = (κ.weights, κ.kernels) #To check
+
+trainable(κ::ScaledKernel) = (κ.σ², κ.kernel)
+
+trainable(κ::TransformedKernel) = (κ.transform, κ.kernel)
+
+### Transforms
+
+trainable(t::ARDTransform) = (t.v,)
+
+trainable(t::ChainTransform) = t.transforms
+
+trainable(t::FunctionTransform) = (t.f,)
+
+trainable(t::LowRankTransform) = (t.proj,)
+
+trainable(t::ScaleTransform) = (t.s,)
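
A short sketch of what these definitions return and how they are meant to be consumed; the `Flux.params` call assumes a Flux release that recurses through `trainable` when collecting parameters, and that `ScaledKernel` stores its scale as the 1-element vector `σ²`:

```julia
k = 2.0 * MaternKernel(ν=2.5)   # ScaledKernel wrapping a MaternKernel

Flux.trainable(k)               # (k.σ², k.kernel)
Flux.trainable(k.kernel)        # (k.kernel.ν,) == ([2.5],)

ps = Flux.params(k)             # gathers [2.0] and [2.5] for use with Flux optimisers
```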

src/transform/ardtransform.jl

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ function set!(t::ARDTransform{T},ρ::AbstractVector{T}) where {T<:Real}
     t.v .= ρ
 end
 
-params(t::ARDTransform) = t.v
 dim(t::ARDTransform) = length(t.v)
 
 function apply(t::ARDTransform,X::AbstractMatrix{<:Real};obsdim::Int = defaultobs)

src/transform/chaintransform.jl

Lines changed: 0 additions & 2 deletions
@@ -32,10 +32,8 @@ function apply(t::ChainTransform,X::T;obsdim::Int=defaultobs) where {T}
 end
 
 set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
-params(t::ChainTransform) = (params.(t.transforms))
 duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
 
-
 Base.:∘(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
 Base.:∘(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
 Base.:∘(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))

src/transform/functiontransform.jl

Lines changed: 1 addition & 2 deletions
@@ -4,7 +4,7 @@ FunctionTransform
 f(x) = abs.(x)
 tr = FunctionTransform(f)
 ```
-Take a function `f` as an argument which is going to act on each vector individually.
+Take a function or object `f` as an argument which is going to act on each vector individually.
 Make sure that `f` is supposed to act on a vector by eventually using broadcasting
 For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
 """
@@ -15,4 +15,3 @@ end
 apply(t::FunctionTransform, X::T; obsdim::Int = defaultobs) where {T} = mapslices(t.f, X, dims = feature_dim(obsdim))
 
 duplicate(t::FunctionTransform,f) = FunctionTransform(f)
-params(t::FunctionTransform) = t.f

src/transform/lowranktransform.jl

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) whe
     t.proj .= M
 end
 
-params(t::LowRankTransform) = t.proj
 
 Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
 Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test

src/transform/scaletransform.jl

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
 end
 
 set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]
-params(t::ScaleTransform) = t.s
 dim(str::ScaleTransform) = 1
 
 apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s) * x

src/transform/selecttransform.jl

Lines changed: 0 additions & 3 deletions
@@ -22,11 +22,8 @@ end
 
 set!(t::SelectTransform{<:AbstractVector{T}},dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
 
-params(t::SelectTransform) = t.select
-
 duplicate(t::SelectTransform,θ) = t
 
-
 Base.maximum(t::SelectTransform) = maximum(t.select)
 
 function apply(t::SelectTransform, X::AbstractMatrix{<:Real}; obsdim::Int = defaultobs)

src/transform/transform.jl

Lines changed: 9 additions & 5 deletions
@@ -7,15 +7,19 @@ include("functiontransform.jl")
 include("selecttransform.jl")
 include("chaintransform.jl")
 
+"""
+`apply(t::Transform, x; obsdim::Int=defaultobs)`
+Apply the transform `t` per slice on the array `x`
+"""
+apply
+
 """
 IdentityTransform
 Return exactly the input
 """
 struct IdentityTransform <: Transform end
 
-params(t::IdentityTransform) = nothing
-
-apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test
+apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x
 
 ### TODO Maybe defining adjoints could help but so far it's not working
 
@@ -32,9 +36,9 @@ apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test
 
 # @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = transform(t,x),Δ->(ScaleTransform(nothing),t.s.*Δ)
 #
-# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
+# @adjoint transform(t::ARDTransform{<:Real},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
 # @show Δ,size(Δ);
-# return (obsdim == 1 ? ScaleTransform()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
+# return (obsdim == 1 ? ARD()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
 # end
 #
 # @adjoint transform(t::ScaleTransform{T},x::AbstractVecOrMat,obsdim::Int) where {T<:Real} = transform(t,x), Δ->(ScaleTransform(one(T)),t.s.*Δ,nothing)
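
The new `apply` docstring documents the generic entry point for transforms. A minimal sketch using transforms touched by this commit (`apply` may be internal to the package, so it is qualified explicitly; constructors are the ones shown in their respective files, and the chained application order is an assumption based on the `∘` definition in chaintransform.jl):

```julia
X = rand(3, 10)   # 3 features, 10 observations, so obsdim = 2

t = ScaleTransform(2.0)
KernelFunctions.apply(t, X; obsdim = 2)   # multiplies every entry by first(t.s) == 2.0

tc = ScaleTransform(2.0) ∘ FunctionTransform(x -> abs.(x))
KernelFunctions.apply(tc, X; obsdim = 2)  # ChainTransform: abs first, then scaling
```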

src/utils.jl

Lines changed: 2 additions & 1 deletion
@@ -31,11 +31,12 @@ base_kernel(k::Kernel) = eval(nameof(typeof(k)))
 base_transform(t::Transform) = eval(nameof(typeof(t)))
 
 """
+Will be implemented at some point
 ```julia
 params(k::Kernel)
 params(t::Transform)
 ```
 For a kernel return a tuple with parameters of the transform followed by the specific parameters of the kernel
 For a transform return its parameters, for a `ChainTransform` return a vector of `params(t)`.
 """
-params(k::Kernel) = (params(transform(k)),)
+#params
