Skip to content

Commit a158aca

Browse files
authored
Update docstrings of linear and polynomial kernel and fix their constraints (#228)
* Update polynomial.jl * Update Project.toml * Fix typo * Update tests * Fix constraints and rename `d` to `degree` * Add constraints for `LinearKernel` * Update documentation * Fix typo
1 parent c8b0bbe commit a158aca

File tree

4 files changed

+76
-33
lines changed

4 files changed

+76
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.15"
3+
version = "0.8.16"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/src/kernels.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ The [`LinearKernel`](@ref) is defined as
178178
k(x,x';c) = \langle x,x'\rangle + c,
179179
```
180180

181-
where $c \in \mathbb{R}$.
181+
where $c \geq 0$.
182182

183183
### Polynomial Kernel
184184

@@ -188,8 +188,7 @@ The [`PolynomialKernel`](@ref) is defined as
188188
k(x,x';c,d) = \left(\langle x,x'\rangle + c\right)^d,
189189
```
190190

191-
where $c \in \mathbb{R}$ and $d>0$.
192-
191+
where $c \geq 0$ and $d \in \mathbb{N}$.
193192

194193
## Rational Quadratic
195194

src/basekernels/polynomial.jl

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
"""
2-
LinearKernel(; c = 0.0)
2+
LinearKernel(; c::Real=0.0)
33
4-
The linear kernel is a Mercer kernel given by
5-
```
6-
κ(x,y) = xᵀy + c
4+
Linear kernel with constant offset `c`.
5+
6+
# Definition
7+
8+
For inputs ``x, x' \\in \\mathbb{R}^d``, the linear kernel with constant offset
9+
``c \\geq 0`` is defined as
10+
```math
11+
k(x, x'; c) = x^\\top x' + c.
712
```
8-
Where `c` is a real number
13+
14+
See also: [`PolynomialKernel`](@ref)
915
"""
1016
struct LinearKernel{Tc<:Real} <: SimpleKernel
1117
c::Vector{Tc}
12-
function LinearKernel(; c::T=0.0) where {T}
13-
return new{T}([c])
18+
19+
function LinearKernel(; c::Real=0.0)
20+
@check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
21+
return new{typeof(c)}([c])
1422
end
1523
end
1624

@@ -23,29 +31,53 @@ metric(::LinearKernel) = DotProduct()
2331
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first.c), ")")
2432

2533
"""
26-
PolynomialKernel(; d = 2.0, c = 0.0)
34+
PolynomialKernel(; degree::Int=2, c::Real=0.0)
2735
28-
The polynomial kernel is a Mercer kernel given by
29-
```
30-
κ(x,y) = (xᵀy + c)^d
36+
Polynomial kernel of degree `degree` with constant offset `c`.
37+
38+
# Definition
39+
40+
For inputs ``x, x' \\in \\mathbb{R}^d``, the polynomial kernel of degree
41+
``\\nu \\in \\mathbb{N}`` with constant offset ``c \\geq 0`` is defined as
42+
```math
43+
k(x, x'; c, \\nu) = (x^\\top x' + c)^\\nu.
3144
```
32-
Where `c` is a real number, and `d` is a shape parameter bigger than 1. For `d = 1` see [`LinearKernel`](@ref)
45+
46+
See also: [`LinearKernel`](@ref)
3347
"""
34-
struct PolynomialKernel{Td<:Real,Tc<:Real} <: SimpleKernel
35-
d::Vector{Td}
48+
struct PolynomialKernel{Tc<:Real} <: SimpleKernel
49+
degree::Int
3650
c::Vector{Tc}
37-
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real,Tc<:Real}
38-
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
39-
return new{Td,Tc}([d], [c])
51+
52+
function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
53+
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
54+
@check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
55+
return new{Tc}(degree, c)
4056
end
4157
end
4258

43-
@functor PolynomialKernel
59+
function PolynomialKernel(; d::Real=-1, degree::Int=2, c::Real=0.0)
60+
if d != -1
61+
Base.depwarn(
62+
"keyword argument `d` is deprecated, use `degree` instead",
63+
:PiecewisePolynomialKernel,
64+
)
65+
isinteger(d) || error("polynomial degree has to be an integer")
66+
degree::Int = convert(Int, d)
67+
end
68+
return PolynomialKernel{typeof(c)}(degree, [c])
69+
end
70+
71+
# The degree of the polynomial kernel is a fixed discrete parameter
72+
function Functors.functor(::Type{<:PolynomialKernel}, x)
73+
reconstruct_polynomialkernel(xs) = PolynomialKernel{typeof(xs.c)}(x.degree, xs.c)
74+
return (c=x.c,), reconstruct_polynomialkernel
75+
end
4476

45-
kappa::PolynomialKernel, xᵀy::Real) = (xᵀy + first.c))^(first.d))
77+
kappa::PolynomialKernel, xᵀy::Real) = (xᵀy + first.c))^κ.degree
4678

4779
metric(::PolynomialKernel) = DotProduct()
4880

4981
function Base.show(io::IO, κ::PolynomialKernel)
50-
return print(io, "Polynomial Kernel (c = ", first.c), ", d = ", first.d), ")")
82+
return print(io, "Polynomial Kernel (c = ", first.c), ", degree = ", κ.degree, ")")
5183
end

test/basekernels/polynomial.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
x = rand(rng) * 2
44
v1 = rand(rng, 3)
55
v2 = rand(rng, 3)
6-
c = randn(rng)
6+
c = rand(rng)
77
@testset "LinearKernel" begin
88
k = LinearKernel()
99
@test kappa(k, x) x
1010
@test k(v1, v2) dot(v1, v2)
1111
@test kappa(LinearKernel(), x) == kappa(k, x)
1212
@test metric(LinearKernel()) == KernelFunctions.DotProduct()
13-
@test metric(LinearKernel(; c=2.0)) == KernelFunctions.DotProduct()
13+
@test metric(LinearKernel(; c=c)) == KernelFunctions.DotProduct()
1414
@test repr(k) == "Linear Kernel (c = 0.0)"
1515

16+
# Errors.
17+
@test_throws ArgumentError LinearKernel(; c=-0.5)
18+
1619
# Standardised tests.
1720
TestUtils.test_interface(k, Float64)
1821
test_ADs(x -> LinearKernel(; c=x[1]), [c])
@@ -23,18 +26,27 @@
2326
@test kappa(k, x) x^2
2427
@test k(v1, v2) dot(v1, v2)^2
2528
@test kappa(PolynomialKernel(), x) == kappa(k, x)
26-
@test repr(k) == "Polynomial Kernel (c = 0.0, d = 2.0)"
29+
@test repr(k) == "Polynomial Kernel (c = 0.0, degree = 2)"
2730

2831
# Coherence tests.
29-
@test kappa(PolynomialKernel(; d=1.0, c=c), x) kappa(LinearKernel(; c=c), x)
32+
@test kappa(PolynomialKernel(; degree=1, c=c), x) kappa(LinearKernel(; c=c), x)
3033
@test metric(PolynomialKernel()) == KernelFunctions.DotProduct()
31-
@test metric(PolynomialKernel(; d=3.0)) == KernelFunctions.DotProduct()
32-
@test metric(PolynomialKernel(; d=3.0, c=2.0)) == KernelFunctions.DotProduct()
34+
@test metric(PolynomialKernel(; degree=3)) == KernelFunctions.DotProduct()
35+
@test metric(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct()
36+
37+
# Deprecations.
38+
k = @test_deprecated PolynomialKernel(; d=1)
39+
@test k.degree == 1
40+
41+
# Errors.
42+
@test_throws ArgumentError PolynomialKernel(; d=0)
43+
@test_throws ArgumentError PolynomialKernel(; degree=0)
44+
@test_throws ArgumentError PolynomialKernel(; c=-0.5)
45+
@test_throws ErrorException PolynomialKernel(; d=2.5)
3346

3447
# Standardised tests.
3548
TestUtils.test_interface(k, Float64)
36-
# test_ADs(x->PolynomialKernel(d=x[1], c=x[2]),[2.0, c])
37-
@test_broken "All, because of the power"
38-
test_params(PolynomialKernel(; d=x, c=c), ([x], [c]))
49+
test_ADs(x -> PolynomialKernel(; c=x[1]), [c])
50+
test_params(PolynomialKernel(; c=c), ([c],))
3951
end
4052
end

0 commit comments

Comments
 (0)