Skip to content

Commit 3bd5cf6

Browse files
authored
Merge pull request #54 from sharanry/maha
Add Mahalanobis distance-based kernel
2 parents 0654f2f + 5c7295e commit 3bd5cf6

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export ExponentiatedKernel
1414
export MaternKernel, Matern32Kernel, Matern52Kernel
1515
export LinearKernel, PolynomialKernel
1616
export RationalQuadraticKernel, GammaRationalQuadraticKernel
17+
export MahalanobisKernel
1718
export KernelSum, KernelProduct
1819
export TransformedKernel, ScaledKernel
1920

@@ -44,7 +45,7 @@ include("distances/dotproduct.jl")
4445
include("distances/delta.jl")
4546
include("transform/transform.jl")
4647

47-
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine"]
48+
for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha"]
4849
include(joinpath("kernels",k*".jl"))
4950
end
5051
include("kernels/transformedkernel.jl")

src/kernels/maha.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
MahalanobisKernel(P::AbstractMatrix)
3+
4+
Mahalanobis distance-based kernel given by
5+
```math
6+
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
7+
```
8+
where the matrix P is the metric.
9+
10+
"""
11+
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: BaseKernel
12+
P::A
13+
function MahalanobisKernel(P::AbstractMatrix{T}) where {T<:Real}
14+
LinearAlgebra.checksquare(P)
15+
new{T,typeof(P)}(P)
16+
end
17+
end
18+
19+
kappa::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)
20+
21+
metric::MahalanobisKernel) = SqMahalanobis.P)

src/trainable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ trainable(k::PolynomialKernel) = (k.d, k.c)
1616

1717
trainable(k::RationalQuadraticKernel) = (k.α,)
1818

19+
trainable(k::MahalanobisKernel) = (k.P,)
20+
1921
#### Composite kernels
2022

2123
trainable::KernelProduct) = κ.kernels

test/test_kernels.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
113113
@test kappa(PolynomialKernel(d=1.0,c=c),x) kappa(LinearKernel(c=c),x)
114114
end
115115
end
116+
@testset "Mahalanobis" begin
117+
P = rand(3,3)
118+
k = MahalanobisKernel(P)
119+
@test kappa(k,x) == exp(-x)
120+
@test k(v1,v2) exp(-sqmahalanobis(v1,v2, k.P))
121+
@test kappa(ExponentialKernel(),x) == kappa(k,x)
122+
end
116123
@testset "RationalQuadratic" begin
117124
@testset "RationalQuadraticKernel" begin
118125
α = 2.0

0 commit comments

Comments
 (0)