Skip to content

Commit 197a4d2

Browse files
willtebbutttheogfdevmotion
authored
Move to AbstractVector (#99)
* Removes _kernel * Remove _scale * Removes binary kappa * Removes k(x, y) -> Matrix * Fixes tests and adds eval * Removes metric from Transformed * Refactors eval * Remove eval from MyKernel * Improves style src/basekernels/fbm.jl Co-Authored-By: Théo Galy-Fajou <[email protected]> * Reverts removal of _scale * Reverts change to transforms. Co-Authored-By: Théo Galy-Fajou <[email protected]> * Requires Julia 1.3 * Removes 1.0 CI * Reverts rationalquad.jl changes * Removes custom show * Adds reference to #96 * Moves abstract types around * Adds pairwise for RowVecs / ColVecs * Adds pairwise for AV{<:Real} * Refactors transforms * Refactors base kernels * Refactors composite kernels * Refactors kernelmatrix tests * Move tests around * Moves KernelProduct tests around * Final tweaks * More tests * Style fix Co-Authored-By: Théo Galy-Fajou <[email protected]> * Style fix Co-Authored-By: Théo Galy-Fajou <[email protected]> * Style fix Co-Authored-By: Théo Galy-Fajou <[email protected]> * Style fix Co-Authored-By: Théo Galy-Fajou <[email protected]> * Style fix Co-Authored-By: Théo Galy-Fajou <[email protected]> * Fix Identity transform * Fixes import bug * eachslice -> eachrow / eachcol Co-Authored-By: David Widmann <[email protected]> * Tweaks slice for AbstractVector{<:Real} Co-Authored-By: David Widmann <[email protected]> * Style Co-Authored-By: David Widmann <[email protected]> * Style Co-Authored-By: David Widmann <[email protected]> * Style Co-Authored-By: David Widmann <[email protected]> * Restore fallbacks * Tweaks comment * Extra tests Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 9207a96 commit 197a4d2

34 files changed

+830
-661
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ Requires = "1.0.1"
2020
SpecialFunctions = "0.8, 0.9, 0.10"
2121
StatsBase = "0.32, 0.33"
2222
StatsFuns = "0.8, 0.9"
23-
ZygoteRules = "0.2"
2423
Zygote = "= 0.4.16"
24+
ZygoteRules = "0.2"
2525
julia = "1.3"
2626

2727
[extras]

src/KernelFunctions.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ using StatsFuns: logtwo
3636
using InteractiveUtils: subtypes
3737
using StatsBase
3838

39-
const defaultobs = 2
40-
4139
"""
4240
Abstract type defining a slice-wise transformation on an input matrix
4341
"""

src/basekernels/fbm.jl

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,45 +33,35 @@ const sqroundoff = 1e-15
3333

3434
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h)/2
3535

36-
function kernelmatrix::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
37-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
38-
modX = sum(abs2, X; dims = feature_dim(obsdim))
39-
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
40-
return _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
36+
_mod(x::AbstractVector{<:Real}) = abs2.(x)
37+
_mod(x::ColVecs) = vec(sum(abs2, x.X; dims=1))
38+
_mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))
39+
40+
function kernelmatrix::FBMKernel, x::AbstractVector)
41+
modx = _mod(x)
42+
modxx = pairwise(SqEuclidean(sqroundoff), x)
43+
return _fbm.(modx, modx', modxx, κ.h)
4144
end
4245

43-
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
44-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
45-
modX = sum(abs2, X; dims = feature_dim(obsdim))
46-
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
47-
K .= _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
46+
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector)
47+
modx = _mod(x)
48+
modxx = pairwise(SqEuclidean(sqroundoff), x)
49+
K .= _fbm.(modx, modx', modxx, κ.h)
4850
return K
4951
end
5052

51-
function kernelmatrix(
52-
κ::FBMKernel,
53-
X::AbstractMatrix,
54-
Y::AbstractMatrix;
55-
obsdim::Int = defaultobs,
56-
)
57-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
58-
modX = sum(abs2, X, dims = feature_dim(obsdim))
59-
modY = sum(abs2, Y, dims = feature_dim(obsdim))
60-
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
61-
return _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
53+
function kernelmatrix::FBMKernel, x::AbstractVector, y::AbstractVector)
54+
modxy = pairwise(SqEuclidean(sqroundoff), x, y)
55+
return _fbm.(_mod(x), _mod(y)', modxy, κ.h)
6256
end
6357

6458
function kernelmatrix!(
6559
K::AbstractMatrix,
6660
κ::FBMKernel,
67-
X::AbstractMatrix,
68-
Y::AbstractMatrix;
69-
obsdim::Int = defaultobs,
61+
X::AbstractVector,
62+
Y::AbstractVector,
7063
)
71-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
72-
modX = sum(abs2, X, dims = feature_dim(obsdim))
73-
modY = sum(abs2, Y, dims = feature_dim(obsdim))
74-
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
75-
K .= _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
64+
modxy = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
65+
K .= _fbm.(_mod(x), _mod(y)', modxy, κ.h)
7666
return K
7767
end

src/generic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,16 @@ end
2525

2626
# Fallback implementation of evaluate for `SimpleKernel`s.
2727
(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
28+
29+
# This is type piracy. We should not doing this.
30+
function Distances.pairwise(d::PreMetric, x::AbstractVector{<:Real})
31+
return pairwise(d, reshape(x, :, 1); dims=1)
32+
end
33+
34+
function Distances.pairwise(
35+
d::PreMetric,
36+
x::AbstractVector{<:Real},
37+
y::AbstractVector{<:Real},
38+
)
39+
return pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
40+
end

src/kernels/kernelproduct.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,18 @@ Base.length(k::KernelProduct) = length(k.kernels)
2424

2525
::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
2626

27-
hadamard(x,y) = x.*y
27+
hadamard(x, y) = x .* y
2828

29-
function kernelmatrix(
30-
κ::KernelProduct,
31-
X::AbstractMatrix;
32-
obsdim::Int=defaultobs,
33-
)
34-
reduce(hadamard, kernelmatrix.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
29+
function kernelmatrix::KernelProduct, x::AbstractVector)
30+
return reduce(hadamard, kernelmatrix.kernels[i], x) for i in 1:length(κ))
3531
end
3632

37-
function kernelmatrix(
38-
κ::KernelProduct,
39-
X::AbstractMatrix,
40-
Y::AbstractMatrix;
41-
obsdim::Int=defaultobs,
42-
)
43-
reduce(hadamard, kernelmatrix.kernels[i], X, Y, obsdim = obsdim) for i in 1:length(κ))
33+
function kernelmatrix::KernelProduct, x::AbstractVector, y::AbstractVector)
34+
return reduce(hadamard, kernelmatrix.kernels[i], x, y) for i in 1:length(κ))
4435
end
4536

46-
function kerneldiagmatrix(
47-
κ::KernelProduct,
48-
X::AbstractMatrix;
49-
obsdim::Int=defaultobs,
50-
) #TODO Add test
51-
reduce(hadamard, kerneldiagmatrix.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
37+
function kerneldiagmatrix::KernelProduct, x::AbstractVector)
38+
return reduce(hadamard, kerneldiagmatrix.kernels[i], x) for i in 1:length(κ))
5239
end
5340

5441
function Base.show(io::IO, κ::KernelProduct)

src/kernels/kernelsum.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,16 @@ Base.length(k::KernelSum) = length(k.kernels)
4848

4949
::KernelSum)(x, y) = sum.weights[i] * κ.kernels[i](x, y) for i in 1:length(κ))
5050

51-
function kernelmatrix::KernelSum, X::AbstractMatrix; obsdim::Int = defaultobs)
52-
sum.weights[i] * kernelmatrix.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
51+
function kernelmatrix::KernelSum, x::AbstractVector)
52+
return sum.weights[i] * kernelmatrix.kernels[i], x) for i in 1:length(κ))
5353
end
5454

55-
function kernelmatrix(
56-
κ::KernelSum,
57-
X::AbstractMatrix,
58-
Y::AbstractMatrix;
59-
obsdim::Int = defaultobs,
60-
)
61-
sum.weights[i] * kernelmatrix.kernels[i], X, Y, obsdim = obsdim) for i in 1:length(κ))
55+
function kernelmatrix::KernelSum, x::AbstractVector, y::AbstractVector)
56+
return sum.weights[i] * kernelmatrix.kernels[i], x, y) for i in 1:length(κ))
6257
end
6358

64-
function kerneldiagmatrix(
65-
κ::KernelSum,
66-
X::AbstractMatrix;
67-
obsdim::Int = defaultobs,
68-
)
69-
sum.weights[i] * kerneldiagmatrix.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
59+
function kerneldiagmatrix::KernelSum, x::AbstractVector)
60+
return sum.weights[i] * kerneldiagmatrix.kernels[i], x) for i in 1:length(κ))
7061
end
7162

7263
function Base.show(io::IO, κ::KernelSum)

src/kernels/scaledkernel.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,42 @@ function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real
1313
return ScaledKernel{Tk, Tσ²}(kernel, [σ²])
1414
end
1515

16-
kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)
17-
1816
(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y)
1917

20-
metric(k::ScaledKernel) = metric(k.kernel)
18+
function kernelmatrix::ScaledKernel, x::AbstractVector, y::AbstractVector)
19+
return κ.σ² .* kernelmatrix.kernel, x, y)
20+
end
21+
22+
function kernelmatrix::ScaledKernel, x::AbstractVector)
23+
return κ.σ² .* kernelmatrix.kernel, x)
24+
end
25+
26+
function kerneldiagmatrix::ScaledKernel, x::AbstractVector)
27+
return κ.σ² .* kerneldiagmatrix.kernel, x)
28+
end
29+
30+
function kernelmatrix!(
31+
K::AbstractMatrix,
32+
κ::ScaledKernel,
33+
x::AbstractVector,
34+
y::AbstractVector,
35+
)
36+
kernelmatrix!(K, κ, x, y)
37+
K .*= κ.σ²
38+
return K
39+
end
40+
41+
function kernelmatrix!(K::AbstractMatrix, κ::ScaledKernel, x::AbstractVector)
42+
kernelmatrix!(K, κ, x)
43+
K .*= κ.σ²
44+
return K
45+
end
46+
47+
function kerneldiagmatrix!(K::AbstractVector, κ::ScaledKernel, x::AbstractVector)
48+
kerneldiagmatrix!(K, κ, x)
49+
K .*= κ.σ²
50+
return K
51+
end
2152

2253
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
2354

0 commit comments

Comments
 (0)