@kernel macro for creating kernel #38

Closed
wants to merge 18 commits into from
2 changes: 2 additions & 0 deletions Project.toml
@@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Compat = "2.2, 3"
MacroTools = "0.5"
Distances = "0.9"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10"
26 changes: 16 additions & 10 deletions docs/src/userguide.md
@@ -4,18 +4,25 @@

To create a kernel, choose one of the proposed kernels, see [Kernels](@ref), or create your own, see [Creating Kernels](@ref).
For example, to create a squared exponential kernel:

```julia
k = SqExponentialKernel()
```
Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)) which are directly going to act on the inputs before passing them to the kernel.
For example to premultiply the input by 2.0 we create the kernel the following options are possible

Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [`Transform`](@ref)), which are applied to the inputs before the kernel is evaluated.
For example, the [`ScaleTransform`](@ref) multiplies every sample by a scalar; a `SqExponentialKernel` with a `ScaleTransform(ρ)` is therefore equivalent to a squared exponential kernel with lengthscale `1/ρ`.
Here are some examples of how to use these transformations that are all equivalent:
```julia
k = transform(SqExponentialKernel(),ScaleTransform(2.0)) # returns a TransformedKernel
k = @kernel SqExponentialKernel() l=2.0 # Will be available soon
k = TransformedKernel(SqExponentialKernel(),ScaleTransform(2.0))
k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0)) # Constructor
k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for the constructor
k = transform(SqExponentialKernel(), 2.0) # Syntactic sugar
k = @kernel SqExponentialKernel() l=2.0 # Convenience macro
```
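As a quick sanity sketch (assuming the KernelFunctions.jl API shown above), all of these constructions wrap the same base kernel in the same `ScaleTransform` and therefore agree pointwise:

```julia
using KernelFunctions

x, y = rand(3), rand(3)

# Explicit constructor, wrapper, and syntactic sugar all build the same object.
k1 = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0))
k2 = transform(SqExponentialKernel(), ScaleTransform(2.0))
k3 = transform(SqExponentialKernel(), 2.0)

@assert k1(x, y) == k2(x, y) == k3(x, y)
```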

Check the [`Transform`](@ref) page to see the other options.
To premultiply the kernel by a variance, you can use `*` or create a `ScaledKernel`
To pre-multiply the kernel by a variance parameter, you can use `*` or create a `ScaledKernel`

```julia
k = 3.0*SqExponentialKernel()
k = ScaledKernel(SqExponentialKernel(),3.0)
@@ -29,7 +36,7 @@ To compute the kernel function on two vectors you can call
k = SqExponentialKernel()
x1 = rand(3)
x2 = rand(3)
k(x1,x2)
k(x1, x2)
```

## Creating a kernel matrix
@@ -40,9 +47,8 @@ For example:
```julia
k = SqExponentialKernel()
A = rand(10,5)
kernelmatrix(k,A,obsdim=1) # Return a 10x10 matrix
kernelmatrix(k,A,obsdim=2) # Return a 5x5 matrix
k(A,obsdim=1) # Syntactic sugar
kernelmatrix(k, A, obsdim = 1) # Return a 10x10 matrix
kernelmatrix(k, A, obsdim = 2) # Return a 5x5 matrix
```
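The effect of `obsdim` can be checked with size assertions (a sketch, assuming the API above): `obsdim = 1` treats the 10 rows as observations, `obsdim = 2` the 5 columns.

```julia
using KernelFunctions

k = SqExponentialKernel()
A = rand(10, 5)

K1 = kernelmatrix(k, A, obsdim = 1)  # rows are observations
K2 = kernelmatrix(k, A, obsdim = 2)  # columns are observations

@assert size(K1) == (10, 10)
@assert size(K2) == (5, 5)
# A kernel matrix over one set of inputs is symmetric.
@assert K1 ≈ K1'
```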

We also support specific kernel matrix outputs:
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
@@ -8,7 +8,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!
export transform
export duplicate, set! # Helpers

export Kernel
export Kernel, BaseKernel, @kernel
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel
export CosineKernel
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
@@ -38,6 +38,7 @@ using SpecialFunctions: loggamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using MacroTools: @capture
using StatsBase

"""
@@ -61,6 +62,7 @@ end

include("kernels/transformedkernel.jl")
include("kernels/scaledkernel.jl")
include("kernels/kernel_macro.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
55 changes: 55 additions & 0 deletions src/kernels/kernel_macro.jl
@@ -0,0 +1,55 @@
"""
@kernel [variance *] kernel
@kernel [variance *] kernel l=Real/Vector
@kernel [variance *] kernel t=transform

The `@kernel` macro is a convenience alias for the [`transform`](@ref) function.
The first argument should be a kernel, optionally premultiplied by a scalar (the variance of the kernel).
The optional second argument is a keyword:
- `l=ρ` where `ρ` is a positive scalar or a vector of scalars
- `t=transform` where `transform` is a [`Transform`](@ref) object

# Examples
```jldoctest
julia> k = @kernel SqExponentialKernel() l=3.0
Squared Exponential Kernel
- Scale Transform (s = 3.0)

julia> k == transform(SqExponentialKernel(), ScaleTransform(3.0))
true

julia> A = rand(4, 3);

julia> k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LinearTransform(A)
Sum of 2 kernels:
- (w = 1.0) Matern Kernel (ν = 3.0)
- (w = 1.0) Linear Kernel (c = 0.0)
- Linear transform (size(A) = (4, 3))

julia> k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LinearTransform(A))
true

julia> k = @kernel 4.0*ExponentiatedKernel() l=3.0
Exponentiated Kernel
- Scale Transform (s = 3.0)
- σ² = 4.0

julia> k == 4.0 * transform(ExponentiatedKernel(), ScaleTransform(3.0))
true
```
"""
macro kernel(expr::Expr, arg = nothing)
@capture(expr, ((scale_ * k_) | (k_)))
if arg === nothing
t = nothing
else
if @capture(arg, ((l = val_) | (t = val_)))
t = val
else
return :(error("The additional argument of `@kernel` is incorrect"))
end
end
if scale === nothing
return :(transform($(esc(k)), $(esc(t))))
else
return :($(esc(scale)) * transform($(esc(k)), $(esc(t))))
end
end
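A hypothetical REPL sketch of how `@capture` (from MacroTools.jl) drives the parsing above: it destructures the expression tree and binds the matched pieces to the underscore-suffixed names.

```julia
using MacroTools: @capture

expr = :(4.0 * ExponentiatedKernel())

# The alternation pattern first tries `scale_ * k_`, then falls back to `k_`.
matched = @capture(expr, ((scale_ * k_) | (k_)))

@assert matched
@assert scale == 4.0
@assert k == :(ExponentiatedKernel())
```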
14 changes: 11 additions & 3 deletions src/kernels/transformedkernel.jl
@@ -9,6 +9,10 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
transform::Tr
end

function TransformedKernel(k::TransformedKernel, t::Transform)
TransformedKernel(kernel(k), t ∘ k.transform)
end
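A sketch of what this constructor buys us (assuming, as the test file suggests, that composing two `Transform`s with `∘` yields a `ChainTransform`): re-transforming a `TransformedKernel` composes the transforms instead of nesting wrappers, so the base kernel stays directly reachable.

```julia
using KernelFunctions

k  = transform(SqExponentialKernel(), 2.0)  # TransformedKernel with a ScaleTransform
k2 = transform(k, 3.0)                      # composes transforms, does not nest kernels

# The base kernel is still one field access away; the two scale
# transforms have been chained rather than wrapped twice.
@assert KernelFunctions.kernel(k2) isa SqExponentialKernel
```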

(k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))

# Optimizations for scale transforms of simple kernels to save allocations:
@@ -41,11 +45,15 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
"""
transform

transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t)
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)

transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))

transform(k::Kernel, ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))

transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))
transform(k::Kernel, ρ::AbstractMatrix) = TransformedKernel(k, LinearTransform(ρ))

transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))
transform(k::Kernel, ::Nothing) = k

kernel(κ) = κ.kernel

10 changes: 10 additions & 0 deletions test/kernels/kernel_macro.jl
@@ -0,0 +1,10 @@
@testset "Kernel Macro" begin
@test (@kernel SqExponentialKernel()) isa SqExponentialKernel
@test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64}
@test (@kernel 3.0 * SqExponentialKernel() l = 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
# @test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
@test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64}
# @test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64}
@test (@kernel 3.0 * SqExponentialKernel() + 5.0 * Matern32Kernel() l = 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}}
@test_throws ErrorException (@kernel SqExponentialKernel() w = 2.0)
end
10 changes: 8 additions & 2 deletions test/kernels/transformedkernel.jl
@@ -7,15 +7,21 @@
s = rand(rng)
v = rand(rng, 3)
k = SqExponentialKernel()
kt = TransformedKernel(k,ScaleTransform(s))
ktard = TransformedKernel(k,ARDTransform(v))
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
@test kt(v1, v2) == transform(k, s)(v1,v2)
@test kt(v1, v2) ≈ k(s * v1, s * v2) atol=1e-5
@test ktard(v1, v2) ≈ transform(k, ARDTransform(v))(v1, v2) atol=1e-5
@test ktard(v1, v2) == transform(k,v)(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)

@test transform(kt, s) isa TransformedKernel{SqExponentialKernel,ChainTransform{Array{ScaleTransform{Float64},1}}}

@test transform(k, s) isa TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}}
@test transform(k, v) isa TransformedKernel{SqExponentialKernel,ARDTransform{Array{Float64,1}}}
@test transform(k, rand(3, 2)) isa TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}}

@testset "kernelmatrix" begin
rng = MersenneTwister(123456)

1 change: 1 addition & 0 deletions test/runtests.jl
@@ -87,6 +87,7 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs
@testset "kernels" begin
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernel_macro.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "tensorproduct.jl"))
include(joinpath("kernels", "transformedkernel.jl"))