Skip to content

Commit 8bb630f

Browse files
committed
Use pairwise! for inplace operations
1 parent 197a4d2 commit 8bb630f

File tree

5 files changed

+43
-13
lines changed

5 files changed

+43
-13
lines changed

src/basekernels/fbm.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ end
4545

4646
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector)
4747
modx = _mod(x)
48-
modxx = pairwise(SqEuclidean(sqroundoff), x)
49-
K .= _fbm.(modx, modx', modxx, κ.h)
48+
pairwise!(K, SqEuclidean(sqroundoff), x)
49+
K .= _fbm.(modx, modx', K, κ.h)
5050
return K
5151
end
5252

@@ -58,10 +58,10 @@ end
5858
function kernelmatrix!(
5959
K::AbstractMatrix,
6060
κ::FBMKernel,
61-
X::AbstractVector,
62-
Y::AbstractVector,
61+
x::AbstractVector,
62+
y::AbstractVector,
6363
)
64-
modxy = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
65-
K .= _fbm.(_mod(x), _mod(y)', modxy, κ.h)
64+
pairwise!(K, SqEuclidean(sqroundoff), x, y)
65+
K .= _fbm.(_mod(x), _mod(y)', K, κ.h)
6666
return K
6767
end

src/generic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,16 @@ function Distances.pairwise(
3838
)
3939
return pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
4040
end
41+
42+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
43+
return pairwise!(out, d, reshape(x, :, 1); dims=1)
44+
end
45+
46+
function Distances.pairwise!(
47+
out::AbstractMatrix,
48+
d::PreMetric,
49+
x::AbstractVector{<:Real},
50+
y::AbstractVector{<:Real},
51+
)
52+
return pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
53+
end

src/matrix/kernelmatrix.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x)
6969

7070
function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector)
7171
validate_inplace_dims(K, x)
72-
return map!(d -> kappa(κ, d), K, pairwise(metric(κ), x))
72+
pairwise!(K, metric(κ), x)
73+
return map!(d -> kappa(κ, d), K, K)
7374
end
7475

7576
function kernelmatrix!(
@@ -79,7 +80,8 @@ function kernelmatrix!(
7980
y::AbstractVector,
8081
)
8182
validate_inplace_dims(K, x, y)
82-
return map!(d -> kappa(κ, d), K, pairwise(metric(κ), x, y))
83+
pairwise!(K, metric(κ), x, y)
84+
return map!(d -> kappa(κ, d), K, K)
8385
end
8486

8587
function kernelmatrix::SimpleKernel, x::AbstractVector)

src/utils.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ dim(x::ColVecs) = size(x.X, 1)
4343

4444
Distances.pairwise(d::PreMetric, x::ColVecs) = pairwise(d, x.X; dims=2)
4545
Distances.pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = pairwise(d, x.X, y.X; dims=2)
46-
46+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
47+
return pairwise!(out, d, x.X; dims=2)
48+
end
49+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
50+
return pairwise!(out, d, x.X, y.X; dims=2)
51+
end
4752

4853
"""
4954
RowVecs(X::AbstractMatrix)
@@ -68,8 +73,12 @@ dim(x::RowVecs) = size(x.X, 2)
6873

6974
Distances.pairwise(d::PreMetric, x::RowVecs) = pairwise(d, x.X; dims=1)
7075
Distances.pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = pairwise(d, x.X, y.X; dims=1)
71-
72-
76+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
77+
return pairwise!(out, d, x.X; dims=1)
78+
end
79+
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
80+
return pairwise!(out, d, x.X, y.X; dims=1)
81+
end
7382

7483
"""
7584
Will be implemented at some point

test/basekernels/fbm.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
# kernelmatrix tests
88
m1 = rand(3,3)
99
m2 = rand(3,3)
10-
@test kernelmatrix(k, m1, m1) kernelmatrix(k, m1) atol=1e-5
11-
10+
Kref = kernelmatrix(k, m1, m1)
11+
@test kernelmatrix(k, m1) Kref atol=1e-5
12+
K = zeros(3, 3)
13+
kernelmatrix!(K, k, m1, m1)
14+
@test K Kref atol=1e-5
15+
fill!(K, 0)
16+
kernelmatrix!(K, k, m1)
17+
@test K Kref atol=1e-5
1218

1319
x1 = rand()
1420
x2 = rand()

0 commit comments

Comments
 (0)