Skip to content

Add metric field #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.9.7"
version = "0.10.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 0 additions & 1 deletion docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ FBMKernel

```@docs
gaborkernel
GaborKernel
```

### Matérn Kernels
Expand Down
5 changes: 1 addition & 4 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module KernelFunctions

export kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag!
export transform
export duplicate, set! # Helpers

export Kernel, MOKernel
Expand All @@ -14,7 +13,7 @@ export FBMKernel
export MaternKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel
export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export GaborKernel, PiecewisePolynomialKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
Expand Down Expand Up @@ -112,8 +111,6 @@ include("zygoterules.jl")

include("test_utils.jl")

include("deprecations.jl")

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
include(joinpath("matrix", "kernelkroneckermat.jl"))
Expand Down
21 changes: 14 additions & 7 deletions src/basekernels/cosine.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
"""
CosineKernel()
CosineKernel(; metric=Euclidean())

Cosine kernel.
Cosine kernel with respect to the `metric`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the cosine kernel is defined as
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the cosine kernel is defined as
```math
k(x, x') = \\cos(\\pi \\|x-x'\\|_2).
k(x, x') = \\cos(\\pi d(x, x')).
```
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
"""
struct CosineKernel <: SimpleKernel end
struct CosineKernel{M} <: SimpleKernel
metric::M

function CosineKernel(; metric=Euclidean())
return new{typeof(metric)}(metric)
end
end

kappa(::CosineKernel, d::Real) = cospi(d)

metric(::CosineKernel) = Euclidean()
metric(k::CosineKernel) = k.metric

Base.show(io::IO, ::CosineKernel) = print(io, "Cosine Kernel")
Base.show(io::IO, k::CosineKernel) = print(io, "Cosine Kernel (metric = ", k.metric, ")")
94 changes: 53 additions & 41 deletions src/basekernels/exponential.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
"""
SqExponentialKernel()
SqExponentialKernel(; metric=Euclidean())

Squared exponential kernel.
Squared exponential kernel with respect to the `metric`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the squared exponential kernel is defined as
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the squared exponential kernel is
defined as
```math
k(x, x') = \\exp\\bigg(- \\frac{\\|x - x'\\|_2^2}{2}\\bigg).
k(x, x') = \\exp\\bigg(- \\frac{d(x, x')^2}{2}\\bigg).
```
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

See also: [`GammaExponentialKernel`](@ref)
"""
struct SqExponentialKernel <: SimpleKernel end
struct SqExponentialKernel{M} <: SimpleKernel
metric::M

kappa(::SqExponentialKernel, d²::Real) = exp(-d² / 2)
function SqExponentialKernel(; metric=Euclidean())
return new{typeof(metric)}(metric)
end
end

kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2)
kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-d² / 2)

metric(::SqExponentialKernel) = SqEuclidean()
metric(k::SqExponentialKernel) = k.metric
metric(::SqExponentialKernel{<:Euclidean}) = SqEuclidean()

iskroncompatible(::SqExponentialKernel) = true

Base.show(io::IO, ::SqExponentialKernel) = print(io, "Squared Exponential Kernel")
function Base.show(io::IO, k::SqExponentialKernel)
return print(io, "Squared Exponential Kernel (metric = ", k.metric, ")")
end

## Aliases ##

Expand All @@ -46,28 +58,37 @@ Alias of [`SqExponentialKernel`](@ref).
const SEKernel = SqExponentialKernel

"""
ExponentialKernel()
ExponentialKernel(; metric=Euclidean())

Exponential kernel.
Exponential kernel with respect to the `metric`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the exponential kernel is defined as
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the exponential kernel is defined as
```math
k(x, x') = \\exp\\big(- \\|x - x'\\|_2\\big).
k(x, x') = \\exp\\big(- d(x, x')\\big).
```
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

See also: [`GammaExponentialKernel`](@ref)
"""
struct ExponentialKernel <: SimpleKernel end
struct ExponentialKernel{M} <: SimpleKernel
metric::M

function ExponentialKernel(; metric=Euclidean())
return new{typeof(metric)}(metric)
end
end

kappa(::ExponentialKernel, d::Real) = exp(-d)

metric(::ExponentialKernel) = Euclidean()
metric(k::ExponentialKernel) = k.metric

iskroncompatible(::ExponentialKernel) = true

Base.show(io::IO, ::ExponentialKernel) = print(io, "Exponential Kernel")
function Base.show(io::IO, k::ExponentialKernel)
return print(io, "Exponential Kernel (metric = ", k.metric, ")")
end

## Aliases ##

Expand All @@ -86,53 +107,44 @@ Alias of [`ExponentialKernel`](@ref).
const Matern12Kernel = ExponentialKernel

"""
GammaExponentialKernel(; γ::Real=2.0)
GammaExponentialKernel(; γ::Real=1.0, metric=Euclidean())

γ-exponential kernel with parameter `γ`.
γ-exponential kernel with respect to the `metric` and with parameter `γ`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the γ-exponential kernel[^RW] with parameter
``\\gamma \\in (0, 2]`` is defined as
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the γ-exponential kernel[^RW] with
parameter ``\\gamma \\in (0, 2]``
is defined as
```math
k(x, x'; \\gamma) = \\exp\\big(- \\|x - x'\\|_2^{\\gamma}\\big).
k(x, x'; \\gamma) = \\exp\\big(- d(x, x')^{\\gamma}\\big).
```

!!! warning
The default value of parameter `γ` will be changed to `1.0` in the next breaking release
of KernelFunctions.
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

See also: [`ExponentialKernel`](@ref), [`SqExponentialKernel`](@ref)

[^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning.
"""
struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel
γ::Vector{Tγ}
# function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma)
function GammaExponentialKernel(; gamma=nothing, γ=gamma)
γ2 = if γ === nothing
Base.depwarn(
"the default value of parameter `γ` of the `GammaExponentialKernel` will " *
"be changed to `1.0` in the next breaking release of KernelFunctions",
:GammaExponentialKernel,
)
2.0
else
γ
end
@check_args(GammaExponentialKernel, γ2, zero(γ2) < γ2 ≤ 2, "γ ∈ (0, 2]")
return new{typeof(γ2)}([γ2])
metric::M

function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean())
@check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
return new{typeof(γ),typeof(metric)}([γ], metric)
end
end

@functor GammaExponentialKernel

kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ))

metric(::GammaExponentialKernel) = Euclidean()
metric(k::GammaExponentialKernel) = k.metric

iskroncompatible(::GammaExponentialKernel) = true

function Base.show(io::IO, κ::GammaExponentialKernel)
return print(io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ")")
return print(
io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")"
)
end
76 changes: 0 additions & 76 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,79 +23,3 @@ function gaborkernel(;
return (SqExponentialKernel() ∘ sqexponential_transform) *
(CosineKernel() ∘ cosine_transform)
end

# everything below will be removed
"""
GaborKernel(; ell::Real=1.0, p::Real=1.0)

Gabor kernel with lengthscale `ell` and period `p`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the Gabor kernel with lengthscale ``l_i > 0``
and period ``p_i > 0`` is defined as
```math
k(x, x'; l, p) = \\exp\\bigg(- \\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{2l_i^2}\\bigg)
\\cos\\bigg(\\pi \\bigg(\\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{p_i^2} \\bigg)^{1/2}\\bigg).
```

!!! note
`GaborKernel` is deprecated and will be removed. Gabor kernels should be
constructed with [`gaborkernel`](@ref) instead.
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
Base.depwarn(
"`GaborKernel` is deprecated and will be removed. Gabor kernels should be " *
"constructed with `gaborkernel` instead.",
:GaborKernel,
)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform)
return new{typeof(k)}(k)
end
end

@functor GaborKernel

(κ::GaborKernel)(x, y) = κ.kernel(x, y)

_lengthscale_transform(::Nothing) = IdentityTransform()
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(x) = 1
_lengthscale(k::TransformedKernel) = _lengthscale(k.transform)
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
return _lengthscale(k.kernel.kernels[1])
elseif v == :p
return _lengthscale(k.kernel.kernels[2])
else
error("Invalid Property")
end
end

function Base.show(io::IO, κ::GaborKernel)
return print(io, "Gabor Kernel (ell = ", κ.ell, ", p = ", κ.p, ")")
end

kernelmatrix(κ::GaborKernel, x::AbstractVector) = kernelmatrix(κ.kernel, x)

function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(κ.kernel, x, y)
end

kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x)

function kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, x, y)
end
Loading