Skip to content

Commit 96086dd

Browse files
authored
Merge pull request #157 from devmotion/functors
2 parents a92fcc2 + 1bba715 commit 96086dd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+133
-196
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.6.1"
3+
version = "0.7.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1617
[compat]
1718
Compat = "2.2, 3"
1819
Distances = "0.9"
20+
Functors = "0.1"
1921
Requires = "1.0.1"
2022
SpecialFunctions = "0.8, 0.9, 0.10"
2123
StatsBase = "0.32, 0.33"

docs/src/create_kernel.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,43 @@ Note that `BaseKernel` do not use `Distances.jl` and can therefore be a bit slow
3333
### Additional Options
3434

3535
Finally there are additional functions you can define to bring in more features:
36-
- `KernelFunctions.trainable(k::MyKernel)`: it defines the trainable parameters of your kernel, it should return a `Tuple` of your parameters.
37-
These parameters will be passed to the `Flux.params` function. For some examples see the `trainable.jl` file in `src/`
3836
- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes in dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
3937
- `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
4038
- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
4139
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` methods to optimize the computations for your kernel.
4240
- `Base.print(io::IO, k::MyKernel)`: define this if you want to customize how your kernel is printed.
41+
42+
KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters
43+
in a way that is compatible with the [Flux ML framework](https://github.com/FluxML/Flux.jl).
44+
You can use `Functors.@functor` if all fields of your kernel struct are trainable. Note that optimization algorithms
45+
in Flux are not compatible with scalar parameters (yet), and hence vector-valued parameters should be preferred.
46+
47+
```julia
48+
import Functors
49+
50+
struct MyKernel{T} <: KernelFunctions.Kernel
51+
a::Vector{T}
52+
end
53+
54+
Functors.@functor MyKernel
55+
```
56+
57+
If only a subset of the fields are trainable, you have to specify explicitly how to (re)construct the kernel with
58+
modified parameter values by [implementing `Functors.functor(::Type{<:MyKernel}, x)` for your kernel struct](https://github.com/FluxML/Functors.jl/issues/3):
59+
60+
```julia
61+
import Functors
62+
63+
struct MyKernel{T} <: KernelFunctions.Kernel
64+
n::Int
65+
a::Vector{T}
66+
end
67+
68+
function Functors.functor(::Type{<:MyKernel}, x::MyKernel)
69+
function reconstruct_mykernel(xs)
70+
# keep field `n` of the original kernel and set `a` to (possibly different) `xs.a`
71+
return MyKernel(x.n, xs.a)
72+
end
73+
return (a = x.a,), reconstruct_mykernel
74+
end
75+
```

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export IndependentMOKernel
3737
using Compat
3838
using Requires
3939
using Distances, LinearAlgebra
40+
using Functors
4041
using SpecialFunctions: loggamma, besselk, polygamma
4142
using ZygoteRules: @adjoint, pullback
4243
using StatsFuns: logtwo
@@ -79,7 +80,6 @@ include("zygote_adjoints.jl")
7980
function __init__()
8081
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
8182
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
82-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
8383
end
8484

8585
end

src/basekernels/constant.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ metric(::ZeroKernel) = Delta()
1515

1616
Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
1717

18-
1918
"""
2019
WhiteKernel()
2120
@@ -55,6 +54,8 @@ struct ConstantKernel{Tc<:Real} <: SimpleKernel
5554
end
5655
end
5756

57+
@functor ConstantKernel
58+
5859
# Constant kernel: returns the first element of `κ.c`, multiplied by `one(x)`.
kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x)
5960

6061
metric(::ConstantKernel) = Delta()

src/basekernels/exponential.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
6363
end
6464
end
6565

66+
@functor GammaExponentialKernel
67+
6668
# Evaluates exp(-(d²)^γ), where γ is the first element of `κ.γ`.
kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^first(κ.γ))
6769

6870
metric(::GammaExponentialKernel) = SqEuclidean()

src/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ struct FBMKernel{T<:Real} <: Kernel
1717
end
1818
end
1919

20+
@functor FBMKernel
21+
2022
function::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
2123
modX = sum(abs2, x)
2224
modY = sum(abs2, y)

src/basekernels/gabor.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct GaborKernel{K<:Kernel} <: Kernel
1515
end
1616
end
1717

18+
@functor GaborKernel
19+
1820
# Calling a GaborKernel forwards both inputs to the kernel stored in `κ.kernel`.
(κ::GaborKernel)(x, y) = κ.kernel(x, y)
1921

2022
function _gabor(; ell = nothing, p = nothing)

src/basekernels/maha.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
1616
end
1717
end
1818

19+
@functor MahalanobisKernel
20+
1921
# Maps the distance value `d` to exp(-d); the kernel instance is only used for dispatch.
kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)
2022

2123
# Metric of the kernel: `SqMahalanobis` constructed from the kernel's matrix `κ.P`.
metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)

src/basekernels/matern.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct MaternKernel{Tν<:Real} <: SimpleKernel
1515
end
1616
end
1717

18+
@functor MaternKernel
19+
1820
@inline function kappa::MaternKernel, d::Real)
1921
result = _matern(first.ν), d)
2022
return ifelse(iszero(d), one(result), result)

src/basekernels/periodic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)
2020

2121
PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
2222

23+
@functor PeriodicKernel
24+
2325
# Metric of the kernel: `Sinus` constructed from the kernel's parameter vector `κ.r`.
metric(κ::PeriodicKernel) = Sinus(κ.r)
2426

2527
# Maps the distance value `d` to exp(-d / 2); the kernel instance is only used for dispatch.
kappa(κ::PeriodicKernel, d::Real) = exp(-0.5d)

src/basekernels/piecewisepolynomial.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ function PiecewisePolynomialKernel(;v::Integer=0, maha::AbstractMatrix{<:Real})
2525
return PiecewisePolynomialKernel{v}(maha)
2626
end
2727

28+
# Have to reconstruct the type parameter
29+
# See also https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663
30+
function Functors.functor(::Type{<:PiecewisePolynomialKernel{V}}, x) where V
31+
function reconstruct_kernel(xs)
32+
return PiecewisePolynomialKernel{V}(xs.maha)
33+
end
34+
return (maha = x.maha,), reconstruct_kernel
35+
end
36+
2837
# Polynomial factor for degree parameter 0: constant 1, independent of `r` and `j`.
_f(κ::PiecewisePolynomialKernel{0}, r, j) = 1
2938
# Polynomial factor for degree parameter 1: linear in `r`.
_f(κ::PiecewisePolynomialKernel{1}, r, j) = 1 + (j + 1) * r
3039
# Polynomial factor for degree parameter 2: quadratic in `r`.
_f(κ::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r.^2

src/basekernels/polynomial.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ struct LinearKernel{Tc<:Real} <: SimpleKernel
1414
end
1515
end
1616

17+
@functor LinearKernel
18+
1719
# Linear kernel: the inner product `xᵀy` shifted by the first element of `κ.c`.
kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
1820

1921
metric(::LinearKernel) = DotProduct()
@@ -38,6 +40,8 @@ struct PolynomialKernel{Td<:Real, Tc<:Real} <: SimpleKernel
3840
end
3941
end
4042

43+
@functor PolynomialKernel
44+
4145
# Polynomial kernel: (xᵀy + c)^d, with c and d the first elements of `κ.c` and `κ.d`.
kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^(first(κ.d))
4246

4347
metric(::PolynomialKernel) = DotProduct()

src/basekernels/rationalquad.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct RationalQuadraticKernel{Tα<:Real} <: SimpleKernel
1515
end
1616
end
1717

18+
@functor RationalQuadraticKernel
19+
1820
# Rational quadratic kernel: (1 + d²/α)^(-α), with α the first element of `κ.α`.
kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T) + d² / first(κ.α))^(-first(κ.α))
1921
metric(::RationalQuadraticKernel) = SqEuclidean()
2022

@@ -38,6 +40,8 @@ struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: SimpleKernel
3840
end
3941
end
4042

43+
@functor GammaRationalQuadraticKernel
44+
4145
# Gamma-rational quadratic kernel: (1 + (d²)^γ/α)^(-α), with α and γ the first
# elements of `κ.α` and `κ.γ`.
kappa(κ::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T) + d²^first(κ.γ) / first(κ.α))^(-first(κ.α))
4246
metric(::GammaRationalQuadraticKernel) = SqEuclidean()
4347

src/kernels/kernelproduct.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ function KernelProduct(kernel::Kernel, kernels::Kernel...)
3939
return KernelProduct((kernel, kernels...))
4040
end
4141

42+
@functor KernelProduct
43+
4244
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct(k1, k2)
4345

4446
function Base.:*(

src/kernels/kernelsum.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ function KernelSum(kernel::Kernel, kernels::Kernel...)
3939
return KernelSum((kernel, kernels...))
4040
end
4141

42+
@functor KernelSum
43+
4244
Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)
4345

4446
function Base.:+(

src/kernels/scaledkernel.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real
1313
return ScaledKernel{Tk, Tσ²}(kernel, [σ²])
1414
end
1515

16+
@functor ScaledKernel
17+
1618
(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y)
1719

1820
function kernelmatrix::ScaledKernel, x::AbstractVector, y::AbstractVector)

src/kernels/tensorproduct.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ function TensorProduct(kernel::Kernel, kernels::Kernel...)
2020
return TensorProduct((kernel, kernels...))
2121
end
2222

23+
@functor TensorProduct
24+
2325
Base.length(kernel::TensorProduct) = length(kernel.kernels)
2426

2527
function (kernel::TensorProduct)(x, y)

src/kernels/transformedkernel.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
99
transform::Tr
1010
end
1111

12+
@functor TransformedKernel
13+
1214
(k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))
1315

1416
# Optimizations for scale transforms of simple kernels to save allocations:

src/trainable.jl

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/transform/ardtransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ end
1414

1515
ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims))
1616

17+
@functor ARDTransform
18+
1719
function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real}
1820
@assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ARDTransform of dimension $(dim(t))"
1921
t.v .= ρ

src/transform/chaintransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform
1313
transforms::V
1414
end
1515

16+
@functor ChainTransform
17+
1618
Base.length(t::ChainTransform) = length(t.transforms)
1719

1820
# Constructor to create a chain transform with an array of parameters

src/transform/functiontransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ struct FunctionTransform{F} <: Transform
1313
f::F
1414
end
1515

16+
@functor FunctionTransform
17+
1618
(t::FunctionTransform)(x) = t.f(x)
1719

1820
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)

src/transform/lineartransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ struct LinearTransform{T<:AbstractMatrix{<:Real}} <: Transform
1717
A::T
1818
end
1919

20+
@functor LinearTransform
21+
2022
function set!(t::LinearTransform{<:AbstractMatrix{T}}, A::AbstractMatrix{T}) where {T<:Real}
2123
size(t.A) == size(A) ||
2224
error("size of the given matrix ", size(A), " and of the transformation matrix ",

src/transform/scaletransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
1515
ScaleTransform{T}([s])
1616
end
1717

18+
@functor ScaleTransform
19+
1820
# Overwrites the stored scale in place: broadcasts the one-element vector `[ρ]` into `t.s`.
set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]
1921

2022
(t::ScaleTransform)(x) = first(t.s) * x

test/basekernels/constant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
@test metric(ConstantKernel()) == KernelFunctions.Delta()
2727
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
2828
@test repr(k) == "Constant Kernel (c = $(c))"
29+
test_params(k, ([c],))
2930
test_ADs(c->ConstantKernel(c=first(c)), [c])
3031
end
3132
end

test/basekernels/exponential.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
@test KernelFunctions.iskroncompatible(k) == true
4141
# AD check w.r.t. γ, restricted to ForwardDiff and ReverseDiff.
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
4242
@test_broken "Zygote gradient given γ"
43+
test_params(k, ([γ],))
4344
#Coherence :
4445
# γ = 1 must agree with the squared-exponential kernel on the same inputs.
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
4546
# γ = 0.5 must agree with the exponential kernel on the same inputs.
@test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2)

test/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@
2424
@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
2525
test_ADs(FBMKernel, ADs = [:ReverseDiff])
2626
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
27+
28+
test_params(k, ([h],))
2729
end

test/basekernels/maha.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@
1313
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
1414
# test_ADs(P -> MahalanobisKernel(P=P), P)
1515
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
16+
17+
test_params(k, (P,))
1618
end

test/basekernels/matern.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
@test repr(k) == "Matern Kernel (ν = $(ν))"
1717
# test_ADs(x->MaternKernel(nu=first(x)),[ν])
1818
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"
19+
test_params(k, ([ν],))
1920
end
2021
@testset "Matern32Kernel" begin
2122
k = Matern32Kernel()

test/basekernels/periodic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
@test repr(k) == "Periodic Kernel, length(r) = $(length(r)))"
1010
# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
1111
@test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff"
12+
test_params(k, (r,))
1213
end

test/basekernels/piecewisepolynomial.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,6 @@
3434
@test repr(k) == "Piecewise Polynomial Kernel (v = $(v), size(maha) = $(size(maha)))"
3535
# test_ADs(maha-> PiecewisePolynomialKernel(v=2, maha = maha), maha)
3636
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
37+
38+
test_params(k, (maha,))
3739
end

0 commit comments

Comments
 (0)