Skip to content

Commit 5dbf881

Browse files
committed
Merge branch 'master' into fix_docs
2 parents fb5ff94 + 86d430c commit 5dbf881

File tree

14 files changed

+118
-71
lines changed

14 files changed

+118
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.4.0"
3+
version = "0.4.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ abstract type BaseKernel <: Kernel end
5050
abstract type SimpleKernel <: BaseKernel end
5151

5252
include("utils.jl")
53+
include("distances/pairwise.jl")
5354
include("distances/dotproduct.jl")
5455
include("distances/delta.jl")
5556
include("distances/sinus.jl")

src/distances/pairwise.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Add our own pairwise function to be able to apply it on vectors
2+
3+
pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, Y')
4+
5+
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
6+
7+
function pairwise!(
8+
out::AbstractMatrix,
9+
d::PreMetric,
10+
X::AbstractVector,
11+
Y::AbstractVector,
12+
)
13+
broadcast!(d, out, X, Y')
14+
end
15+
16+
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
17+
18+
function pairwise(d::PreMetric, x::AbstractVector{<:Real})
19+
return Distances.pairwise(d, reshape(x, :, 1); dims = 1)
20+
end
21+
22+
function pairwise(
23+
d::PreMetric,
24+
x::AbstractVector{<:Real},
25+
y::AbstractVector{<:Real},
26+
)
27+
return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims = 1)
28+
end
29+
30+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
31+
return Distances.pairwise!(out, d, reshape(x, :, 1); dims = 1)
32+
end
33+
34+
function pairwise!(
35+
out::AbstractMatrix,
36+
d::PreMetric,
37+
x::AbstractVector{<:Real},
38+
y::AbstractVector{<:Real},
39+
)
40+
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
41+
end

src/generic.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,3 @@ end
2020

2121
# Fallback implementation of evaluate for `SimpleKernel`s.
2222
(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
23-
24-
# This is type piracy. We should not doing this.
25-
function Distances.pairwise(d::PreMetric, x::AbstractVector{<:Real})
26-
return pairwise(d, reshape(x, :, 1); dims=1)
27-
end
28-
29-
function Distances.pairwise(
30-
d::PreMetric,
31-
x::AbstractVector{<:Real},
32-
y::AbstractVector{<:Real},
33-
)
34-
return pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
35-
end
36-
37-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
38-
return pairwise!(out, d, reshape(x, :, 1); dims=1)
39-
end
40-
41-
function Distances.pairwise!(
42-
out::AbstractMatrix,
43-
d::PreMetric,
44-
x::AbstractVector{<:Real},
45-
y::AbstractVector{<:Real},
46-
)
47-
return pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
48-
end

src/kernels/transformedkernel.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,25 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
3131

3232
"""
3333
```julia
34-
transform(k::BaseKernel, t::Transform) (1)
35-
transform(k::BaseKernel, ρ::Real) (2)
36-
transform(k::BaseKernel, ρ::AbstractVector) (3)
34+
transform(k::Kernel, t::Transform) (1)
35+
transform(k::Kernel, ρ::Real) (2)
36+
transform(k::Kernel, ρ::AbstractVector) (3)
3737
```
3838
(1) Create a TransformedKernel with transform `t` and kernel `k`
3939
(2) Same as (1) with a `ScaleTransform` with scale `ρ`
4040
(3) Same as (1) with an `ARDTransform` with scales `ρ`
4141
"""
4242
transform
4343

44-
transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t)
44+
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)
4545

46-
transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))
46+
function transform(k::TransformedKernel, t::Transform)
47+
return TransformedKernel(k.kernel, t k.transform)
48+
end
49+
50+
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ))
4751

48-
transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))
52+
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ))
4953

5054
kernel(κ) = κ.kernel
5155

src/utils.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
4343

4444
dim(x::ColVecs) = size(x.X, 1)
4545

46-
Distances.pairwise(d::PreMetric, x::ColVecs) = pairwise(d, x.X; dims=2)
47-
Distances.pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = pairwise(d, x.X, y.X; dims=2)
48-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
49-
return pairwise!(out, d, x.X; dims=2)
46+
pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2)
47+
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2)
48+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
49+
return Distances.pairwise!(out, d, x.X; dims=2)
5050
end
51-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
52-
return pairwise!(out, d, x.X, y.X; dims=2)
51+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
52+
return Distances.pairwise!(out, d, x.X, y.X; dims=2)
5353
end
5454

5555
"""
@@ -73,13 +73,13 @@ Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :))
7373

7474
dim(x::RowVecs) = size(x.X, 2)
7575

76-
Distances.pairwise(d::PreMetric, x::RowVecs) = pairwise(d, x.X; dims=1)
77-
Distances.pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = pairwise(d, x.X, y.X; dims=1)
78-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
79-
return pairwise!(out, d, x.X; dims=1)
76+
pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1)
77+
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1)
78+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
79+
return Distances.pairwise!(out, d, x.X; dims=1)
8080
end
81-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
82-
return pairwise!(out, d, x.X, y.X; dims=1)
81+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
82+
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
8383
end
8484

8585
"""

src/zygote_adjoints.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
end
66
end
77

8-
@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9-
D = pairwise(d, X, Y; dims = dims)
8+
@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9+
D = Distances.pairwise(d, X, Y; dims = dims)
1010
if dims == 1
1111
return D, Δ -> (nothing, nothing, nothing)
1212
else
1313
return D, Δ -> (nothing, nothing, nothing)
1414
end
1515
end
1616

17-
@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
18-
D = pairwise(d, X; dims = dims)
17+
@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2)
18+
D = Distances.pairwise(d, X; dims = dims)
1919
if dims == 1
2020
return D, Δ -> (nothing, nothing)
2121
else
@@ -30,17 +30,17 @@ end
3030
end
3131
end
3232

33-
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
34-
D = pairwise(d, X, Y; dims = dims)
33+
@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
34+
D = Distances.pairwise(d, X, Y; dims = dims)
3535
if dims == 1
3636
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
3737
else
3838
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
3939
end
4040
end
4141

42-
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
43-
D = pairwise(d, X; dims = dims)
42+
@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
43+
D = Distances.pairwise(d, X; dims = dims)
4444
if dims == 1
4545
return D, Δ -> (nothing, 2 * Δ * X)
4646
else

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ Kronecker = "0.4"
2121
PDMats = "0.9, 0.10"
2222
ReverseDiff = "1.2"
2323
SpecialFunctions = "0.10"
24-
Zygote = "0.4"
24+
Zygote = "0.4, 0.5"

test/basekernels/gabor.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
@test k.ell 1.0 atol=1e-5
1818
@test k.p 1.0 atol=1e-5
1919
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
20-
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:ForwardDiff, :ReverseDiff])
20+
#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
2121
@test_broken "Tests failing for Zygote on differentiating through ell and p"
22+
# Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly
2223
end

test/distances/pairwise.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
@testset "pairwise" begin
2+
rng = MersenneTwister(123456)
3+
d = SqEuclidean()
4+
Ns = (4, 5)
5+
D = 3
6+
x = [randn(rng, D) for _ in 1:Ns[1]]
7+
y = [randn(rng, D) for _ in 1:Ns[2]]
8+
X = hcat(x...)
9+
Y = hcat(y...)
10+
K = zeros(Ns)
11+
12+
@test KernelFunctions.pairwise(d, x, y) pairwise(d, X, Y, dims=2)
13+
@test KernelFunctions.pairwise(d, x) pairwise(d, X, dims=2)
14+
KernelFunctions.pairwise!(K, d, x, y)
15+
@test K pairwise(d, X, Y, dims=2)
16+
K = zeros(Ns[1], Ns[1])
17+
KernelFunctions.pairwise!(K, d, x)
18+
@test K pairwise(d, X, dims=2)
19+
20+
x = randn(rng, 10)
21+
X = reshape(x, :, 1)
22+
y = randn(rng, 11)
23+
Y = reshape(y, :, 1)
24+
K = zeros(10, 11)
25+
@test KernelFunctions.pairwise(d, x, y) pairwise(d, X, Y; dims=1)
26+
@test KernelFunctions.pairwise(d, x) pairwise(d, X; dims=1)
27+
KernelFunctions.pairwise!(K, d, x, y)
28+
@test K pairwise(d, X, Y, dims=1)
29+
end

test/generic.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,4 @@
33
@test length(k) == 1
44
@test iterate(k) == (k,nothing)
55
@test iterate(k,1) == nothing
6-
7-
rng = MersenneTwister(123456)
8-
x = randn(rng, 10)
9-
X = reshape(x, :, 1)
10-
y = randn(rng, 11)
11-
Y = reshape(y, :, 1)
12-
@test pairwise(SqEuclidean(), x, y) pairwise(SqEuclidean(), X, Y; dims=1)
13-
@test pairwise(SqEuclidean(), x) pairwise(SqEuclidean(), X; dims=1)
146
end

test/kernels/transformedkernel.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
v2 = rand(rng, 3)
66

77
s = rand(rng)
8+
s2 = rand(rng)
89
v = rand(rng, 3)
910
k = SqExponentialKernel()
1011
kt = TransformedKernel(k,ScaleTransform(s))
@@ -15,6 +16,9 @@
1516
@test ktard(v1, v2) transform(k, ARDTransform(v))(v1, v2) atol=1e-5
1617
@test ktard(v1, v2) == transform(k,v)(v1, v2)
1718
@test ktard(v1, v2) == k(v .* v1, v .* v2)
19+
@test transform(kt, s2)(v1, v2) kt(s2 * v1, s2 * v2)
20+
@test KernelFunctions.kernel(kt) == k
21+
@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))
1822

1923
@testset "kernelmatrix" begin
2024
rng = MersenneTwister(123456)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs
4848
include("utils_AD.jl")
4949

5050
@testset "distances" begin
51+
include(joinpath("distances", "pairwise.jl"))
5152
include(joinpath("distances", "dotproduct.jl"))
5253
include(joinpath("distances", "delta.jl"))
5354
include(joinpath("distances", "sinus.jl"))

test/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
Y = randn(rng, D, N + 1)
2424
DY = ColVecs(Y)
25-
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=2)
26-
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=2)
25+
@test KernelFunctions.pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=2)
26+
@test KernelFunctions.pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=2)
2727
K = zeros(N, N)
28-
pairwise!(K, SqEuclidean(), DX)
28+
KernelFunctions.pairwise!(K, SqEuclidean(), DX)
2929
@test K pairwise(SqEuclidean(), X; dims=2)
3030
K = zeros(N, N + 1)
31-
pairwise!(K, SqEuclidean(), DX, DY)
31+
KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY)
3232
@test K pairwise(SqEuclidean(), X, Y; dims=2)
3333

3434
let
@@ -56,13 +56,13 @@
5656

5757
Y = randn(rng, D + 1, N)
5858
DY = RowVecs(Y)
59-
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=1)
60-
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=1)
59+
@test KernelFunctions.pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=1)
60+
@test KernelFunctions.pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=1)
6161
K = zeros(D, D)
62-
pairwise!(K, SqEuclidean(), DX)
62+
KernelFunctions.pairwise!(K, SqEuclidean(), DX)
6363
@test K pairwise(SqEuclidean(), X; dims=1)
6464
K = zeros(D, D + 1)
65-
pairwise!(K, SqEuclidean(), DX, DY)
65+
KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY)
6666
@test K pairwise(SqEuclidean(), X, Y; dims=1)
6767

6868
let

0 commit comments

Comments
 (0)