Skip to content

Commit b10684d

Browse files
authored
Merge pull request JuliaGaussianProcesses#102 from devmotion/pairwise!
2 parents adf55bf + f323f72 commit b10684d

File tree

6 files changed

+55
-13
lines changed

6 files changed

+55
-13
lines changed

src/basekernels/fbm.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ end
4545

4646
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector)
4747
modx = _mod(x)
48-
modxx = pairwise(SqEuclidean(sqroundoff), x)
49-
K .= _fbm.(modx, modx', modxx, κ.h)
48+
pairwise!(K, SqEuclidean(sqroundoff), x)
49+
K .= _fbm.(modx, modx', K, κ.h)
5050
return K
5151
end
5252

@@ -58,10 +58,10 @@ end
5858
function kernelmatrix!(
5959
K::AbstractMatrix,
6060
κ::FBMKernel,
61-
X::AbstractVector,
62-
Y::AbstractVector,
61+
x::AbstractVector,
62+
y::AbstractVector,
6363
)
64-
modxy = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
65-
K .= _fbm.(_mod(x), _mod(y)', modxy, κ.h)
64+
pairwise!(K, SqEuclidean(sqroundoff), x, y)
65+
K .= _fbm.(_mod(x), _mod(y)', K, κ.h)
6666
return K
6767
end

src/generic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,16 @@ function Distances.pairwise(
3333
)
3434
return pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
3535
end
36+
37+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
38+
return pairwise!(out, d, reshape(x, :, 1); dims=1)
39+
end
40+
41+
function Distances.pairwise!(
42+
out::AbstractMatrix,
43+
d::PreMetric,
44+
x::AbstractVector{<:Real},
45+
y::AbstractVector{<:Real},
46+
)
47+
return pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
48+
end

src/matrix/kernelmatrix.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x)
6969

7070
function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector)
7171
validate_inplace_dims(K, x)
72-
return map!(d -> kappa(κ, d), K, pairwise(metric(κ), x))
72+
pairwise!(K, metric(κ), x)
73+
return map!(d -> kappa(κ, d), K, K)
7374
end
7475

7576
function kernelmatrix!(
@@ -79,7 +80,8 @@ function kernelmatrix!(
7980
y::AbstractVector,
8081
)
8182
validate_inplace_dims(K, x, y)
82-
return map!(d -> kappa(κ, d), K, pairwise(metric(κ), x, y))
83+
pairwise!(K, metric(κ), x, y)
84+
return map!(d -> kappa(κ, d), K, K)
8385
end
8486

8587
function kernelmatrix::SimpleKernel, x::AbstractVector)

src/utils.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ dim(x::ColVecs) = size(x.X, 1)
4545

4646
Distances.pairwise(d::PreMetric, x::ColVecs) = pairwise(d, x.X; dims=2)
4747
Distances.pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = pairwise(d, x.X, y.X; dims=2)
48-
48+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
49+
return pairwise!(out, d, x.X; dims=2)
50+
end
51+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
52+
return pairwise!(out, d, x.X, y.X; dims=2)
53+
end
4954

5055
"""
5156
RowVecs(X::AbstractMatrix)
@@ -70,8 +75,12 @@ dim(x::RowVecs) = size(x.X, 2)
7075

7176
Distances.pairwise(d::PreMetric, x::RowVecs) = pairwise(d, x.X; dims=1)
7277
Distances.pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = pairwise(d, x.X, y.X; dims=1)
73-
74-
78+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
79+
return pairwise!(out, d, x.X; dims=1)
80+
end
81+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
82+
return pairwise!(out, d, x.X, y.X; dims=1)
83+
end
7584

7685
"""
7786
Will be implemented at some point

test/basekernels/fbm.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
# kernelmatrix tests
88
m1 = rand(3,3)
99
m2 = rand(3,3)
10-
@test kernelmatrix(k, m1, m1) kernelmatrix(k, m1) atol=1e-5
11-
10+
Kref = kernelmatrix(k, m1, m1)
11+
@test kernelmatrix(k, m1) Kref atol=1e-5
12+
K = zeros(3, 3)
13+
kernelmatrix!(K, k, m1, m1)
14+
@test K Kref atol=1e-5
15+
fill!(K, 0)
16+
kernelmatrix!(K, k, m1)
17+
@test K Kref atol=1e-5
1218

1319
x1 = rand()
1420
x2 = rand()

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
DY = ColVecs(Y)
2525
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=2)
2626
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=2)
27+
K = zeros(N, N)
28+
pairwise!(K, SqEuclidean(), DX)
29+
@test K pairwise(SqEuclidean(), X; dims=2)
30+
K = zeros(N, N + 1)
31+
pairwise!(K, SqEuclidean(), DX, DY)
32+
@test K pairwise(SqEuclidean(), X, Y; dims=2)
2733

2834
let
2935
@test Zygote.pullback(ColVecs, X)[1] == DX
@@ -52,6 +58,12 @@
5258
DY = RowVecs(Y)
5359
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=1)
5460
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=1)
61+
K = zeros(D, D)
62+
pairwise!(K, SqEuclidean(), DX)
63+
@test K pairwise(SqEuclidean(), X; dims=1)
64+
K = zeros(D, D + 1)
65+
pairwise!(K, SqEuclidean(), DX, DY)
66+
@test K pairwise(SqEuclidean(), X, Y; dims=1)
5567

5668
let
5769
@test Zygote.pullback(RowVecs, X)[1] == DX

0 commit comments

Comments
 (0)