Skip to content

params(), Flux/Zygote style #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ julia = "1.0"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]
1 change: 1 addition & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ include("zygote_adjoints.jl")
function __init__()
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
end

end
9 changes: 3 additions & 6 deletions src/kernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,12 @@ metric(::WhiteKernel) = Delta()
Kernel function always returning a constant value `c`
"""
struct ConstantKernel{Tc<:Real} <: BaseKernel
c::Tc
c::Vector{Tc}
function ConstantKernel(;c::T=1.0) where {T<:Real}
new{T}(c)
new{T}([c])
end
end

params(k::ConstantKernel) = (k.c,)
opt_params(k::ConstantKernel) = (k.c,)

kappa(κ::ConstantKernel,x::Real) = κ.c*one(x)
kappa(κ::ConstantKernel,x::Real) = first(κ.c)*one(x)

metric(::ConstantKernel) = Delta()
9 changes: 3 additions & 6 deletions src/kernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
```
"""
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
γ::Tγ
γ::Vector{Tγ}
function GammaExponentialKernel(;γ::T=2.0) where {T<:Real}
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")
return new{T}(γ)
return new{T}([γ])
end
end

params(k::GammaExponentialKernel) = (γ,)
opt_params(k::GammaExponentialKernel) = (γ,)

kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^κ.γ)
kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^first(κ.γ))
iskroncompatible(::GammaExponentialKernel) = true
metric(::GammaExponentialKernel) = SqEuclidean()
3 changes: 0 additions & 3 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ struct KernelProduct <: Kernel
kernels::Vector{Kernel}
end

params(k::KernelProduct) = params.(k.kernels)
opt_params(k::KernelProduct) = opt_params.(k.kernels)

Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
Expand Down
3 changes: 0 additions & 3 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ function KernelSum(
return KernelSum(kernels, weights)
end

params(k::KernelSum) = (k.weights, params.(k.kernels))
opt_params(k::KernelSum) = (k.weights, opt_params.(k.kernels))

Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
Base.:+(k1::KernelSum, k2::KernelSum) =
Expand Down
12 changes: 6 additions & 6 deletions src/kernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ The matern kernel is an isotropic Mercer kernel given by the formula:
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
"""
struct MaternKernel{Tν<:Real} <: BaseKernel
ν::Tν
ν::Vector{Tν}
function MaternKernel(;ν::T=1.5) where {T<:Real}
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
return new{T}(ν)
return new{T}([ν])
end
end

params(k::MaternKernel) = (k.ν,)
opt_params(k::MaternKernel) = (k.ν,)

@inline kappa(κ::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((one(d)-κ.ν)*logtwo-logabsgamma(κ.ν)[1] + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk(κ.ν,sqrt(2κ.ν)*d)))
@inline function kappa(κ::MaternKernel, d::Real)
ν = first(κ.ν)
iszero(d) ? one(d) : exp((one(d)-ν)*logtwo-logabsgamma(ν)[1] + ν*log(sqrt(2ν)*d)+log(besselk(ν,sqrt(2ν)*d)))
end

metric(::MaternKernel) = Euclidean()

Expand Down
22 changes: 8 additions & 14 deletions src/kernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ The linear kernel is a Mercer kernel given by
Where `c` is a real number
"""
struct LinearKernel{Tc<:Real} <: BaseKernel
c::Tc
c::Vector{Tc}
function LinearKernel(;c::T=0.0) where {T}
new{T}(c)
new{T}([c])
end
end

params(k::LinearKernel) = (k.c,)
opt_params(k::LinearKernel) = (k.c,)

kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + κ.c
kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)

metric(::LinearKernel) = DotProduct()

Expand All @@ -28,18 +25,15 @@ The polynomial kernel is a Mercer kernel given by
```
Where `c` is a real number, and `d` is a shape parameter bigger than 1
"""
struct PolynomialKernel{Td<:Real,Tc<:Real} <: BaseKernel
d::Td
c::Tc
struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
d::Vector{Td}
c::Vector{Tc}
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
return new{Td, Tc}(d, c)
return new{Td, Tc}([d], [c])
end
end

params(k::PolynomialKernel) = (k.d,k.c)
opt_params(k::PolynomialKernel) = (k.d,k.c)

kappa(κ::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^(κ.d)
kappa(κ::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first(κ.c))^(first(κ.d))

metric(::PolynomialKernel) = DotProduct()
20 changes: 7 additions & 13 deletions src/kernels/rationalquad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@ The rational-quadratic kernel is an isotropic Mercer kernel given by the formula
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
"""
struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
α::Tα
α::Vector{Tα}
function RationalQuadraticKernel(;α::T=2.0) where {T}
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
return new{T}(α)
return new{T}([α])
end
end

params(k::RationalQuadraticKernel) = (k.α,)
opt_params(k::RationalQuadraticKernel) = (k.α,)

kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/κ.α)^(-κ.α)
kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/first(κ.α))^(-first(κ.α))

metric(::RationalQuadraticKernel) = SqEuclidean()

Expand All @@ -30,18 +27,15 @@ The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the f
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
"""
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
α::Tα
γ::Tγ
α::Vector{Tα}
γ::Vector{Tγ}
function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2.0) where {Tα<:Real, Tγ<:Real}
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
return new{Tα, Tγ}(α, γ)
return new{Tα, Tγ}([α], [γ])
end
end

params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
opt_params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)

kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²^κ.γ/κ.α)^(-κ.α)
kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²^first(κ.γ)/first(κ.α))^(-first(κ.α))

metric(::GammaRationalQuadraticKernel) = SqEuclidean()
3 changes: 0 additions & 3 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)

metric(k::ScaledKernel) = metric(k.kernel)

params(k::ScaledKernel) = (k.σ², params(k.kernel))
opt_params(k::ScaledKernel) = (k.σ², opt_params(k.kernel))

Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)

Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)
Expand Down
2 changes: 0 additions & 2 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ kappa(κ::TransformedKernel, x) = kappa(κ.kernel, x)

metric(κ::TransformedKernel) = metric(κ.kernel)

params(κ::TransformedKernel) = (params(κ.transform),params(κ.kernel))

Base.show(io::IO,κ::TransformedKernel) = printshifted(io,κ,0)

function printshifted(io::IO,κ::TransformedKernel,shift::Int)
Expand Down
39 changes: 39 additions & 0 deletions src/trainable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import .Flux.trainable

### Base Kernels

trainable(k::ConstantKernel) = (k.c,)

trainable(k::GammaExponentialKernel) = (k.γ,)

trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)

trainable(k::MaternKernel) = (k.ν,)

trainable(k::LinearKernel) = (k.c,)

trainable(k::PolynomialKernel) = (k.d, k.c)

trainable(k::RationalQuadraticKernel) = (k.α,)

#### Composite kernels

trainable(κ::KernelProduct) = κ.kernels

trainable(κ::KernelSum) = (κ.weights, κ.kernels) #To check

trainable(κ::ScaledKernel) = (κ.σ², κ.kernel)

trainable(κ::TransformedKernel) = (κ.transform, κ.kernel)

### Transforms

trainable(t::ARDTransform) = (t.v,)

trainable(t::ChainTransform) = t.transforms

trainable(t::FunctionTransform) = (t.f,)

trainable(t::LowRankTransform) = (t.proj,)

trainable(t::ScaleTransform) = (t.s,)
1 change: 0 additions & 1 deletion src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ function set!(t::ARDTransform{T},ρ::AbstractVector{T}) where {T<:Real}
t.v .= ρ
end

params(t::ARDTransform) = t.v
dim(t::ARDTransform) = length(t.v)

function apply(t::ARDTransform,X::AbstractMatrix{<:Real};obsdim::Int = defaultobs)
Expand Down
2 changes: 0 additions & 2 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ function apply(t::ChainTransform,X::T;obsdim::Int=defaultobs) where {T}
end

set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
params(t::ChainTransform) = (params.(t.transforms))
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))


Base.:∘(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
Base.:∘(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
Base.:∘(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
3 changes: 1 addition & 2 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ FunctionTransform
f(x) = abs.(x)
tr = FunctionTransform(f)
```
Take a function `f` as an argument which is going to act on each vector individually.
Take a function or object `f` as an argument which is going to act on each vector individually.
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
"""
Expand All @@ -15,4 +15,3 @@ end
apply(t::FunctionTransform, X::T; obsdim::Int = defaultobs) where {T} = mapslices(t.f, X, dims = feature_dim(obsdim))

duplicate(t::FunctionTransform,f) = FunctionTransform(f)
params(t::FunctionTransform) = t.f
1 change: 0 additions & 1 deletion src/transform/lowranktransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) whe
t.proj .= M
end

params(t::LowRankTransform) = t.proj

Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test
Expand Down
1 change: 0 additions & 1 deletion src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
end

set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]
params(t::ScaleTransform) = t.s
dim(str::ScaleTransform) = 1

apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s) * x
Expand Down
3 changes: 0 additions & 3 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ end

set!(t::SelectTransform{<:AbstractVector{T}},dims::AbstractVector{T}) where {T<:Int} = t.select .= dims

params(t::SelectTransform) = t.select

duplicate(t::SelectTransform,θ) = t


Base.maximum(t::SelectTransform) = maximum(t.select)

function apply(t::SelectTransform, X::AbstractMatrix{<:Real}; obsdim::Int = defaultobs)
Expand Down
14 changes: 9 additions & 5 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ include("functiontransform.jl")
include("selecttransform.jl")
include("chaintransform.jl")

"""
`apply(t::Transform, x; obsdim::Int=defaultobs)`
Apply the transform `t` per slice on the array `x`
"""
apply

"""
IdentityTransform
Return exactly the input
"""
struct IdentityTransform <: Transform end

params(t::IdentityTransform) = nothing

apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test
apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x

### TODO Maybe defining adjoints could help but so far it's not working

Expand All @@ -32,9 +36,9 @@ apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test

# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = transform(t,x),Δ->(ScaleTransform(nothing),t.s.*Δ)
#
# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
# @adjoint transform(t::ARDTransform{<:Real},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
# @show Δ,size(Δ);
# return (obsdim == 1 ? ScaleTransform()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
# return (obsdim == 1 ? ARD()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
# end
#
# @adjoint transform(t::ScaleTransform{T},x::AbstractVecOrMat,obsdim::Int) where {T<:Real} = transform(t,x), Δ->(ScaleTransform(one(T)),t.s.*Δ,nothing)
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ base_kernel(k::Kernel) = eval(nameof(typeof(k)))
base_transform(t::Transform) = eval(nameof(typeof(t)))

"""
Will be implemented at some point
```julia
params(k::Kernel)
params(t::Transform)
```
For a kernel return a tuple with parameters of the transform followed by the specific parameters of the kernel
For a transform return its parameters, for a `ChainTransform` return a vector of `params(t)`.
"""
params(k::Kernel) = (params(transform(k)),)
#params
Loading