Skip to content

Commit b2c2e52

Browse files
authored
Deprecate use of Mahalanobis distance (#225)
1 parent ee8d7ce commit b2c2e52

File tree

11 files changed

+133
-121
lines changed

11 files changed

+133
-121
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
- name: Format check
6363
run: |
6464
CHANGED="$(git diff --name-only)"
65-
if [ ! -z $CHANGED ]; then
65+
if [ ! -z "$CHANGED" ]; then
6666
>&2 echo "Some files have not been formatted !!!"
6767
echo "$CHANGED"
6868
exit 1

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.11"
3+
version = "0.8.12"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -17,7 +17,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1717
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1818

1919
[compat]
20-
Compat = "2.2, 3"
20+
Compat = "3.7"
2121
Distances = "0.9.1, 0.10"
2222
Functors = "0.1"
2323
Requires = "1.0.1"

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ NeuralNetworkKernel
3535
LinearKernel
3636
PolynomialKernel
3737
PiecewisePolynomialKernel
38-
MahalanobisKernel
3938
RationalQuadraticKernel
4039
GammaRationalQuadraticKernel
4140
spectral_mixture_kernel

docs/src/kernels.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,22 @@ where $r$ has the same dimension as $x$ and $r_i > 0$.
153153

154154
## Piecewise Polynomial Kernel
155155

156-
The [`PiecewisePolynomialKernel`](@ref) is defined for $x, x'\in \mathbb{R}^D$, a positive-definite matrix $P \in \mathbb{R}^{D \times D}$, and $V \in \{0,1,2,3\}$ as
156+
The [`PiecewisePolynomialKernel`](@ref) of degree $v \in \{0,1,2,3\}$ is defined for
157+
inputs $x, x' \in \mathbb{R}^d$ of dimension $d$ as
157158
```math
158-
k(x,x'; P, V) = \max(1 - \sqrt{x^\top P x'}, 0)^{j + V} f_V(\sqrt{x^\top P x'}, j),
159+
k(x, x'; v) = \max(1 - \|x - x'\|, 0)^{\alpha} f_{v,d}(\|x - x'\|),
159160
```
160-
where $j = \lfloor \frac{D}{2}\rfloor + V + 1$, and $f_V$ are polynomials defined as follows:
161+
where $\alpha = \lfloor \frac{d}{2}\rfloor + 2v + 1$, and $f_{v,d}$ are polynomials of
162+
degree $v$ given by
161163
```math
162164
\begin{aligned}
163-
f_0(r, j) &= 1, \\
164-
f_1(r, j) &= 1 + (j + 1) r, \\
165-
f_2(r, j) &= 1 + (j + 2) r + ((j^2 + 4j + 3) / 3) r^2, \\
166-
f_3(r, j) &= 1 + (j + 3) r + ((6 j^2 + 36j + 45) / 15) r^2 + ((j^3 + 9 j^2 + 23j + 15) / 15) r^3.
165+
f_{0,d}(r) &= 1, \\
166+
f_{1,d}(r) &= 1 + (j + 1) r, \\
167+
f_{2,d}(r) &= 1 + (j + 2) r + \big((j^2 + 4j + 3) / 3\big) r^2, \\
168+
f_{3,d}(r) &= 1 + (j + 3) r + \big((6 j^2 + 36j + 45) / 15\big) r^2 + \big((j^3 + 9 j^2 + 23j + 15) / 15\big) r^3,
167169
\end{aligned}
168170
```
171+
where $j = \lfloor \frac{d}{2}\rfloor + v + 1$.
169172

170173
## Polynomial Kernels
171174

docs/src/userguide.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@ For example, a squared exponential kernel is created by
2323
k = 3.0 * SqExponentialKernel()
2424
```
2525

26+
!!! tip "How do I use a Mahalanobis kernel?"
27+
The `MahalanobisKernel(; P=P)`, defined by
28+
```math
29+
k(x, x'; P) = \exp{\big(- (x - x')^\top P (x - x')\big)}
30+
```
31+
for a positive definite matrix $P = Q^\top Q$, is deprecated. Instead you can
32+
use a squared exponential kernel together with a [`LinearTransform`](@ref) of
33+
the inputs:
34+
```julia
35+
k = transform(SqExponentialKernel(), LinearTransform(sqrt(2) .* Q))
36+
```
37+
Analogously, you can combine other kernels such as the
38+
[`PiecewisePolynomialKernel`](@ref) with a [`LinearTransform`](@ref) of the
39+
inputs to obtain a kernel that is a function of the Mahalanobis distance
40+
between inputs.
41+
2642
## Using a kernel function
2743

2844
To evaluate the kernel function on two vectors you simply call the kernel object:

src/KernelFunctions.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export FBMKernel
3131
export MaternKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel
3232
export LinearKernel, PolynomialKernel
3333
export RationalQuadraticKernel, GammaRationalQuadraticKernel
34-
export MahalanobisKernel, GaborKernel, PiecewisePolynomialKernel
34+
export GaborKernel, PiecewisePolynomialKernel
3535
export PeriodicKernel, NeuralNetworkKernel
3636
export KernelSum, KernelProduct
3737
export TransformedKernel, ScaledKernel
@@ -90,7 +90,6 @@ include(joinpath("basekernels", "exponential.jl"))
9090
include(joinpath("basekernels", "exponentiated.jl"))
9191
include(joinpath("basekernels", "fbm.jl"))
9292
include(joinpath("basekernels", "gabor.jl"))
93-
include(joinpath("basekernels", "maha.jl"))
9493
include(joinpath("basekernels", "matern.jl"))
9594
include(joinpath("basekernels", "nn.jl"))
9695
include(joinpath("basekernels", "periodic.jl"))
@@ -118,6 +117,8 @@ include("zygote_adjoints.jl")
118117

119118
include("test_utils.jl")
120119

120+
include("deprecated.jl")
121+
121122
function __init__()
122123
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
123124
include(joinpath("matrix", "kernelkroneckermat.jl"))

src/basekernels/maha.jl

Lines changed: 0 additions & 27 deletions
This file was deleted.
src/basekernels/piecewisepolynomial.jl

Lines changed: 71 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,89 @@
1-
"""
2-
PiecewisePolynomialKernel{V}(maha::AbstractMatrix)
1+
@doc raw"""
2+
PiecewisePolynomialKernel(; degree::Int=0, dim::Int)
3+
PiecewisePolynomialKernel{degree}(dim::Int)
4+
5+
Piecewise polynomial kernel of degree `degree` for inputs of dimension `dim` with support in
6+
the unit ball.
7+
8+
# Definition
39
4-
Piecewise Polynomial covariance function with compact support, V = 0,1,2,3.
5-
The kernel functions are 2V times continuously differentiable and the corresponding
6-
processes are hence V times mean-square differentiable. The kernel function is:
10+
For inputs ``x, x' \in \mathbb{R}^d`` of dimension ``d``, the piecewise polynomial kernel
11+
of degree ``v \in \{0,1,2,3\}`` is defined as
712
```math
8-
κ(x, y) = max(1 - r, 0)^(j + V) * f(r, j) with j = floor(D / 2) + V + 1
13+
k(x, x'; v) = \max(1 - \|x - x'\|, 0)^{\alpha(v,d)} f_{v,d}(\|x - x'\|),
914
```
10-
where `r` is the Mahalanobis distance mahalanobis(x,y) with `maha` as the metric.
15+
where ``\alpha(v, d) = \lfloor \frac{d}{2}\rfloor + 2v + 1`` and ``f_{v,d}`` are
16+
polynomials of degree ``v`` given by
17+
```math
18+
\begin{aligned}
19+
f_{0,d}(r) &= 1, \\
20+
f_{1,d}(r) &= 1 + (j + 1) r, \\
21+
f_{2,d}(r) &= 1 + (j + 2) r + \big((j^2 + 4j + 3) / 3\big) r^2, \\
22+
f_{3,d}(r) &= 1 + (j + 3) r + \big((6 j^2 + 36j + 45) / 15\big) r^2 + \big((j^3 + 9 j^2 + 23j + 15) / 15\big) r^3,
23+
\end{aligned}
24+
```
25+
where ``j = \lfloor \frac{d}{2}\rfloor + v + 1``.
26+
27+
The kernel is ``2v`` times continuously differentiable and the corresponding Gaussian
28+
process is hence ``v`` times mean-square differentiable.
1129
"""
12-
struct PiecewisePolynomialKernel{V,A<:AbstractMatrix{<:Real}} <: SimpleKernel
13-
maha::A
14-
j::Int
15-
function PiecewisePolynomialKernel{V}(maha::AbstractMatrix{<:Real}) where {V}
16-
V in (0, 1, 2, 3) || error("Invalid parameter V=$(V). Should be 0, 1, 2 or 3.")
17-
LinearAlgebra.checksquare(maha)
18-
j = div(size(maha, 1), 2) + V + 1
19-
return new{V,typeof(maha)}(maha, j)
30+
struct PiecewisePolynomialKernel{D,C<:Tuple} <: SimpleKernel
31+
alpha::Int
32+
coeffs::C
33+
34+
function PiecewisePolynomialKernel{D}(dim::Int) where {D}
35+
dim > 0 || error("number of dimensions has to be positive")
36+
j = div(dim, 2) + D + 1
37+
alpha = j + D
38+
coeffs = piecewise_polynomial_coefficients(Val(D), j)
39+
return new{D,typeof(coeffs)}(alpha, coeffs)
2040
end
2141
end
2242

23-
function PiecewisePolynomialKernel(; v::Integer=0, maha::AbstractMatrix{<:Real})
24-
return PiecewisePolynomialKernel{v}(maha)
25-
end
43+
# TODO: remove `maha` keyword argument in next breaking release
44+
function PiecewisePolynomialKernel(; v::Int=-1, degree::Int=v, maha=nothing, dim::Int=-1)
45+
if v != -1
46+
Base.depwarn(
47+
"keyword argument `v` is deprecated, use `degree` instead",
48+
:PiecewisePolynomialKernel,
49+
)
50+
end
2651

27-
# Have to reconstruct the type parameter
28-
# See also https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663
29-
function Functors.functor(::Type{<:PiecewisePolynomialKernel{V}}, x) where {V}
30-
function reconstruct_kernel(xs)
31-
return PiecewisePolynomialKernel{V}(xs.maha)
52+
if maha !== nothing
53+
Base.depwarn(
54+
"keyword argument `maha` is deprecated, use a `LinearTransform` instead",
55+
:PiecewisePolynomialKernel,
56+
)
57+
dim = size(maha, 1)
58+
return transform(
59+
PiecewisePolynomialKernel{degree}(dim), LinearTransform(cholesky(maha).U)
60+
)
61+
else
62+
return PiecewisePolynomialKernel{degree}(dim)
3263
end
33-
return (maha=x.maha,), reconstruct_kernel
3464
end
3565

36-
_f(κ::PiecewisePolynomialKernel{0}, r, j) = 1
37-
_f(κ::PiecewisePolynomialKernel{1}, r, j) = 1 + (j + 1) * r
38-
_f(κ::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r .^ 2
39-
function _f(κ::PiecewisePolynomialKernel{3}, r, j)
40-
return 1 +
41-
(j + 3) * r +
42-
(6 * j^2 + 36j + 45) / 15 * r .^ 2 +
43-
(j^3 + 9 * j^2 + 23j + 15) / 15 * r .^ 3
66+
piecewise_polynomial_coefficients(::Val{0}, ::Int) = (1,)
67+
piecewise_polynomial_coefficients(::Val{1}, j::Int) = (1, j + 1)
68+
piecewise_polynomial_coefficients(::Val{2}, j::Int) = (1, j + 2, (j^2 + 4 * j)//3 + 1)
69+
function piecewise_polynomial_coefficients(::Val{3}, j::Int)
70+
return (1, j + 3, (2 * j^2 + 12 * j)//5 + 3, (j^3 + 9 * j^2 + 23 * j)//15 + 1)
4471
end
45-
46-
function kappa(κ::PiecewisePolynomialKernel{V}, r) where {V}
47-
return max(1 - r, 0)^(κ.j + V) * _f(κ, r, κ.j)
72+
function piecewise_polynomial_coefficients(::Val{D}, ::Int) where {D}
73+
return error("invalid degree $D, only 0, 1, 2, or 3 are supported")
4874
end
4975

50-
metric(κ::PiecewisePolynomialKernel) = Mahalanobis(κ.maha)
76+
kappa(κ::PiecewisePolynomialKernel, r) = max(1 - r, 0)^κ.alpha * evalpoly(r, κ.coeffs)
77+
78+
metric(::PiecewisePolynomialKernel) = Euclidean()
5179

52-
function Base.show(io::IO, κ::PiecewisePolynomialKernel{V}) where {V}
80+
function Base.show(io::IO, κ::PiecewisePolynomialKernel{D}) where {D}
5381
return print(
54-
io, "Piecewise Polynomial Kernel (v = ", V, ", size(maha) = ", size(κ.maha), ")"
82+
io,
83+
"Piecewise Polynomial Kernel (degree = ",
84+
D,
85+
", ⌊dim/2⌋ = ",
86+
κ.alpha - 1 - 2 * D,
87+
")",
5588
)
5689
end

src/deprecated.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# TODO: remove the corresponding tests when this deprecated `MahalanobisKernel` method is removed
2+
@deprecate MahalanobisKernel(; P::AbstractMatrix{<:Real}) transform(
3+
SqExponentialKernel(), LinearTransform(sqrt(2) .* cholesky(P).U)
4+
)
5+
6+
# TODO: remove keyword argument `maha` when this deprecated constructor is removed
7+
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
8+
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
9+
)

test/basekernels/maha.jl

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
@testset "maha" begin
22
rng = MersenneTwister(123456)
3-
x = 2 * rand(rng)
43
D_in = 3
54
v1 = rand(rng, D_in)
65
v2 = rand(rng, D_in)
@@ -9,40 +8,10 @@
98
P = Matrix(Cholesky(U, 'U', 0))
109
@assert isposdef(P)
1110

12-
k = MahalanobisKernel(; P=P)
13-
14-
@test kappa(k, x) == exp(-x)
11+
k = @test_deprecated MahalanobisKernel(; P=P)
12+
@test k isa TransformedKernel{SqExponentialKernel,<:LinearTransform}
13+
@test k.transform.A ≈ sqrt(2) .* U
1514
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
16-
@test kappa(ExponentialKernel(), x) == kappa(k, x)
17-
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
18-
19-
M1, M2 = rand(rng, 3, 2), rand(rng, 3, 2)
20-
21-
function FiniteDifferences.to_vec(dist::SqMahalanobis)
22-
return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
23-
end
24-
a = rand()
25-
26-
function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
27-
return MahalanobisKernel(; P=Array(U' * U))(v1, v2)
28-
end
29-
30-
@test all(
31-
FiniteDifferences.j′vp(FDM, test_mahakernel, a, U, v1, v2)[1] .≈
32-
UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]),
33-
)
34-
35-
function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
36-
return SqMahalanobis(Array(U' * U))(v1, v2)
37-
end
38-
39-
@test all(
40-
FiniteDifferences.j′vp(FDM, test_sqmaha, a, U, v1, v2)[1] .≈
41-
UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]),
42-
)
43-
44-
# test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
45-
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
4615

4716
# Standardised tests.
4817
@testset "ColVecs" begin
@@ -57,5 +26,4 @@
5726
x2 = RowVecs(randn(2, D_in))
5827
TestUtils.test_interface(k, x0, x1, x2)
5928
end
60-
test_params(k, (P,))
6129
end

test/basekernels/piecewisepolynomial.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,32 @@
33
v1 = rand(D)
44
v2 = rand(D)
55
maha = Matrix{Float64}(I, D, D)
6-
v = 3
7-
k = PiecewisePolynomialKernel{v}(maha)
6+
degree = 3
87

9-
k2 = PiecewisePolynomialKernel(; v=v, maha=maha)
8+
k = PiecewisePolynomialKernel(; degree=degree, dim=D)
9+
k2 = PiecewisePolynomialKernel{degree}(D)
10+
k3 = @test_deprecated PiecewisePolynomialKernel{degree}(maha)
11+
k4 = @test_deprecated PiecewisePolynomialKernel(; degree=degree, maha=maha)
12+
k5 = @test_deprecated PiecewisePolynomialKernel(; v=degree, dim=D)
13+
k6 = @test_deprecated PiecewisePolynomialKernel(; v=degree, maha=maha)
1014

11-
@test k2(v1, v2) ≈ k(v1, v2) atol = 1e-5
15+
@test k2(v1, v2) == k(v1, v2)
16+
@test k3(v1, v2) ≈ k(v1, v2)
17+
@test k4(v1, v2) ≈ k(v1, v2)
18+
@test k5(v1, v2) ≈ k(v1, v2)
19+
@test k6(v1, v2) ≈ k(v1, v2)
1220

1321
@test_throws ErrorException PiecewisePolynomialKernel{4}(maha)
22+
@test_throws ErrorException PiecewisePolynomialKernel{4}(D)
23+
@test_throws ErrorException PiecewisePolynomialKernel{degree}(-1)
1424

15-
@test repr(k) == "Piecewise Polynomial Kernel (v = $(v), size(maha) = $(size(maha)))"
25+
@test repr(k) ==
26+
"Piecewise Polynomial Kernel (degree = $(degree), ⌊dim/2⌋ = $(div(D, 2)))"
1627

1728
# Standardised tests.
1829
TestUtils.test_interface(k, ColVecs{Float64}; dim_in=2)
1930
TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2)
20-
# test_ADs(maha-> PiecewisePolynomialKernel(v=2, maha = maha), maha)
21-
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
31+
test_ADs(() -> PiecewisePolynomialKernel{degree}(D))
2232

23-
test_params(k, (maha,))
33+
test_params(k, ())
2434
end

0 commit comments

Comments
 (0)