Skip to content

Update docstrings of linear and polynomial kernel and fix their constraints #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.15"
version = "0.8.16"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
5 changes: 2 additions & 3 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ The [`LinearKernel`](@ref) is defined as
k(x,x';c) = \langle x,x'\rangle + c,
```

where $c \in \mathbb{R}$.
where $c \geq 0$.

### Polynomial Kernel

Expand All @@ -190,8 +190,7 @@ The [`PolynomialKernel`](@ref) is defined as
k(x,x';c,d) = \left(\langle x,x'\rangle + c\right)^d,
```

where $c \in \mathbb{R}$ and $d>0$.

where $c \geq 0$ and $d \in \mathbb{N}$ with $d \geq 1$.

## Rational Quadratic

Expand Down
72 changes: 52 additions & 20 deletions src/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""
LinearKernel(; c = 0.0)
LinearKernel(; c::Real=0.0)

The linear kernel is a Mercer kernel given by
```
κ(x,y) = xᵀy + c
Linear kernel with constant offset `c`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the linear kernel with constant offset
``c \\geq 0`` is defined as
```math
k(x, x'; c) = x^\\top x' + c.
```
Where `c` is a real number

See also: [`PolynomialKernel`](@ref)
"""
struct LinearKernel{Tc<:Real} <: SimpleKernel
c::Vector{Tc}
function LinearKernel(; c::T=0.0) where {T}
return new{T}([c])

# Inner keyword constructor for `LinearKernel`.
#
# Keyword `c` (default `0.0`) is the constant offset of the kernel. The struct's type
# parameter is taken from the concrete type of the supplied `c`.
function LinearKernel(; c::Real=0.0)
# Validate the documented constraint `c ≥ 0`; `zero(c)` keeps the comparison in the
# same numeric type as `c`. (The accompanying tests expect an `ArgumentError` for a
# negative offset.)
@check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
# The offset is stored wrapped in a one-element vector, matching the `c::Vector{Tc}` field.
return new{typeof(c)}([c])
end
end

Expand All @@ -23,29 +31,53 @@ metric(::LinearKernel) = DotProduct()
# Print a compact one-line description of the linear kernel, including its offset `c`.
function Base.show(io::IO, κ::LinearKernel)
    offset = first(κ.c)
    return print(io, "Linear Kernel (c = ", offset, ")")
end

"""
PolynomialKernel(; d = 2.0, c = 0.0)
PolynomialKernel(; degree::Int=2, c::Real=0.0)

The polynomial kernel is a Mercer kernel given by
```
κ(x,y) = (xᵀy + c)^d
Polynomial kernel of degree `degree` with constant offset `c`.

# Definition

For inputs ``x, x' \\in \\mathbb{R}^d``, the polynomial kernel of degree
``\\nu \\in \\mathbb{N}`` with constant offset ``c \\geq 0`` is defined as
```math
k(x, x'; c, \\nu) = (x^\\top x' + c)^\\nu.
```
Where `c` is a real number, and `d` is a shape parameter bigger than 1. For `d = 1` see [`LinearKernel`](@ref)

See also: [`LinearKernel`](@ref)
"""
struct PolynomialKernel{Td<:Real,Tc<:Real} <: SimpleKernel
d::Vector{Td}
struct PolynomialKernel{Tc<:Real} <: SimpleKernel
degree::Int
c::Vector{Tc}
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real,Tc<:Real}
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
return new{Td,Tc}([d], [c])

# Inner constructor for `PolynomialKernel`.
#
# Arguments:
# - `degree::Int`: degree of the polynomial; must satisfy `degree ≥ 1`.
# - `c::Vector{Tc}`: vector whose first element is the constant offset; that element
#   must satisfy `c ≥ 0`.
function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
# Both constraints are checked here so that every construction path goes through the
# same validation. (The accompanying tests expect an `ArgumentError` on violation.)
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
# Only the first element of `c` is inspected.
@check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
return new{Tc}(degree, c)
end
end

@functor PolynomialKernel
# Keyword constructor for `PolynomialKernel`.
#
# Keywords:
# - `degree::Int=2`: degree of the polynomial.
# - `c::Real=0.0`: constant offset; the kernel's type parameter is `typeof(c)`.
# - `d::Real=-1`: deprecated alias for `degree`. The default `-1` acts as a sentinel for
#   "not provided"; any other value overrides `degree` and emits a deprecation warning.
#
# Throws an `ErrorException` if `d` is provided but is not integer-valued. Constraint
# checks on `degree` and `c` are delegated to `PolynomialKernel{T}(degree, [c])`.
function PolynomialKernel(; d::Real=-1, degree::Int=2, c::Real=0.0)
    if d != -1
        # Attribute the deprecation to `PolynomialKernel` itself (the previous symbol,
        # `:PiecewisePolynomialKernel`, named an unrelated kernel).
        Base.depwarn(
            "keyword argument `d` is deprecated, use `degree` instead",
            :PolynomialKernel,
        )
        isinteger(d) || error("polynomial degree has to be an integer")
        degree::Int = convert(Int, d)
    end
    return PolynomialKernel{typeof(c)}(degree, [c])
end

# The degree of the polynomial kernel is a fixed discrete parameter, so only the offset
# `c` is exposed as a functor child; `degree` is carried over from `x` on reconstruction.
function Functors.functor(::Type{<:PolynomialKernel}, x)
    # `c` is a `Vector{Tc}` and `PolynomialKernel` is parameterised by the element type
    # `Tc`, so the type parameter must be `eltype(xs.c)`; `typeof(xs.c)` would be the
    # vector type and would not match the `PolynomialKernel{Tc}(::Int, ::Vector{Tc})`
    # constructor.
    reconstruct_polynomialkernel(xs) = PolynomialKernel{eltype(xs.c)}(x.degree, xs.c)
    return (c=x.c,), reconstruct_polynomialkernel
end

kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^(first(κ.d))
# Evaluate the polynomial kernel on a precomputed dot product `xᵀy`:
# add the constant offset and raise the sum to the (integer) degree.
function kappa(κ::PolynomialKernel, xᵀy::Real)
    offset = first(κ.c)
    return (xᵀy + offset)^κ.degree
end

# The polynomial kernel is evaluated on the dot product of its inputs.
function metric(::PolynomialKernel)
    return DotProduct()
end

function Base.show(io::IO, κ::PolynomialKernel)
return print(io, "Polynomial Kernel (c = ", first(κ.c), ", d = ", first(κ.d), ")")
return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")")
end
30 changes: 21 additions & 9 deletions test/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
x = rand(rng) * 2
v1 = rand(rng, 3)
v2 = rand(rng, 3)
c = randn(rng)
c = rand(rng)
@testset "LinearKernel" begin
k = LinearKernel()
@test kappa(k, x) ≈ x
@test k(v1, v2) ≈ dot(v1, v2)
@test kappa(LinearKernel(), x) == kappa(k, x)
@test metric(LinearKernel()) == KernelFunctions.DotProduct()
@test metric(LinearKernel(; c=2.0)) == KernelFunctions.DotProduct()
@test metric(LinearKernel(; c=c)) == KernelFunctions.DotProduct()
@test repr(k) == "Linear Kernel (c = 0.0)"

# Errors.
@test_throws ArgumentError LinearKernel(; c=-0.5)

# Standardised tests.
TestUtils.test_interface(k, Float64)
test_ADs(x -> LinearKernel(; c=x[1]), [c])
Expand All @@ -23,18 +26,27 @@
@test kappa(k, x) ≈ x^2
@test k(v1, v2) ≈ dot(v1, v2)^2
@test kappa(PolynomialKernel(), x) == kappa(k, x)
@test repr(k) == "Polynomial Kernel (c = 0.0, d = 2.0)"
@test repr(k) == "Polynomial Kernel (c = 0.0, degree = 2)"

# Coherence tests.
@test kappa(PolynomialKernel(; d=1.0, c=c), x) ≈ kappa(LinearKernel(; c=c), x)
@test kappa(PolynomialKernel(; degree=1, c=c), x) ≈ kappa(LinearKernel(; c=c), x)
@test metric(PolynomialKernel()) == KernelFunctions.DotProduct()
@test metric(PolynomialKernel(; d=3.0)) == KernelFunctions.DotProduct()
@test metric(PolynomialKernel(; d=3.0, c=2.0)) == KernelFunctions.DotProduct()
@test metric(PolynomialKernel(; degree=3)) == KernelFunctions.DotProduct()
@test metric(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct()

# Deprecations.
k = @test_deprecated PolynomialKernel(; d=1)
@test k.degree == 1

# Errors.
@test_throws ArgumentError PolynomialKernel(; d=0)
@test_throws ArgumentError PolynomialKernel(; degree=0)
@test_throws ArgumentError PolynomialKernel(; c=-0.5)
@test_throws ErrorException PolynomialKernel(; d=2.5)

# Standardised tests.
TestUtils.test_interface(k, Float64)
# test_ADs(x->PolynomialKernel(d=x[1], c=x[2]),[2.0, c])
@test_broken "All, because of the power"
test_params(PolynomialKernel(; d=x, c=c), ([x], [c]))
test_ADs(x -> PolynomialKernel(; c=x[1]), [c])
test_params(PolynomialKernel(; c=c), ([c],))
end
end