Skip to content

Commit f5bcae4

Browse files
authored
Use Iterators.peel, reduce allocations, and add missing kernelmatrix_diag! definitions + tests (#379)
* Use `Iterators.peel` and reduce allocations * Add missing test for binary `kernelmatrix_diag!` * Add missing definition for `kernelmatrix_diag!` * Bump version * Fix some bugs * Add missing `validate_inplace_dims` method
1 parent 9f708d0 commit f5bcae4

File tree

5 files changed

+85
-21
lines changed

5 files changed

+85
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.23"
3+
version = "0.10.24"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/kernels/kerneltensorproduct.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVec
7171
validate_inplace_dims(K, x)
7272
validate_domain(k, x)
7373

74-
kernels_and_inputs = zip(k.kernels, slices(x))
75-
kernelmatrix!(K, first(kernels_and_inputs)...)
76-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
77-
K .*= kernelmatrix(k, xi)
74+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
75+
first_x, tail_x = Iterators.peel(slices(x))
76+
77+
# handle first kernel and input
78+
kernelmatrix!(K, first_kernels, first_x)
79+
80+
# handle remaining kernels and inputs
81+
Ktmp = similar(K)
82+
for (ki, xi) in zip(tail_kernels, tail_x)
83+
kernelmatrix!(Ktmp, ki, xi)
84+
hadamard!(K, K, Ktmp)
7885
end
7986

8087
return K
@@ -86,10 +93,18 @@ function kernelmatrix!(
8693
validate_inplace_dims(K, x, y)
8794
validate_domain(k, x)
8895

89-
kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
90-
kernelmatrix!(K, first(kernels_and_inputs)...)
91-
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
92-
K .*= kernelmatrix(k, xi, yi)
96+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
97+
first_x, tail_x = Iterators.peel(slices(x))
98+
first_y, tail_y = Iterators.peel(slices(y))
99+
100+
# handle first kernel and inputs
101+
kernelmatrix!(K, first_kernels, first_x, first_y)
102+
103+
# handle remaining kernels and inputs
104+
Ktmp = similar(K)
105+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
106+
kernelmatrix!(Ktmp, ki, xi, yi)
107+
hadamard!(K, K, Ktmp)
93108
end
94109

95110
return K
@@ -99,10 +114,40 @@ function kernelmatrix_diag!(K::AbstractVector, k::KernelTensorProduct, x::Abstra
99114
validate_inplace_dims(K, x)
100115
validate_domain(k, x)
101116

102-
kernels_and_inputs = zip(k.kernels, slices(x))
103-
kernelmatrix_diag!(K, first(kernels_and_inputs)...)
104-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
105-
K .*= kernelmatrix_diag(k, xi)
117+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
118+
first_x, tail_x = Iterators.peel(slices(x))
119+
120+
# handle first kernel and input
121+
kernelmatrix_diag!(K, first_kernels, first_x)
122+
123+
# handle remaining kernels and inputs
124+
Ktmp = similar(K)
125+
for (ki, xi) in zip(tail_kernels, tail_x)
126+
kernelmatrix_diag!(Ktmp, ki, xi)
127+
hadamard!(K, K, Ktmp)
128+
end
129+
130+
return K
131+
end
132+
133+
function kernelmatrix_diag!(
134+
K::AbstractVector, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
135+
)
136+
validate_inplace_dims(K, x, y)
137+
validate_domain(k, x)
138+
139+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
140+
first_x, tail_x = Iterators.peel(slices(x))
141+
first_y, tail_y = Iterators.peel(slices(y))
142+
143+
# handle first kernel and inputs
144+
kernelmatrix_diag!(K, first_kernels, first_x, first_y)
145+
146+
# handle remaining kernels and inputs
147+
Ktmp = similar(K)
148+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
149+
kernelmatrix_diag!(Ktmp, ki, xi, yi)
150+
hadamard!(K, K, Ktmp)
106151
end
107152

108153
return K

src/kernels/scaledkernel.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ function kernelmatrix_diag!(K::AbstractVector, κ::ScaledKernel, x::AbstractVect
6161
return K
6262
end
6363

64+
function kernelmatrix_diag!(
65+
K::AbstractVector, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
66+
)
67+
kernelmatrix_diag!(K, κ.kernel, x, y)
68+
K .*= κ.σ²
69+
return K
70+
end
71+
6472
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
6573

6674
Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

src/test_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function test_interface(
8787

8888
tmp_diag = Vector{Float64}(undef, length(x0))
8989
@test kernelmatrix_diag!(tmp_diag, k, x0) kernelmatrix_diag(k, x0)
90+
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) kernelmatrix_diag(k, x0, x1)
9091
end
9192

9293
function test_interface(

src/utils.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,27 @@ function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::Abstract
197197
end
198198
end
199199

200-
function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector)
201-
return validate_inplace_dims(K, x, x)
202-
end
203-
204-
function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
205-
if length(K) != length(x)
200+
function validate_inplace_dims(K::AbstractVector, x::AbstractVector, y::AbstractVector)
201+
validate_inputs(x, y)
202+
n = length(x)
203+
if length(y) != n
204+
throw(
205+
DimensionMismatch(
206+
"Length of input x ($n) not consistent with length of input y " *
207+
"($(length(y))",
208+
),
209+
)
210+
end
211+
if length(K) != n
206212
throw(
207213
DimensionMismatch(
208-
"Length of target vector K ($(length(K))) not consistent with length of input" *
209-
"vector x ($(length(x))",
214+
"Length of target vector K ($(length(K))) not consistent with length of " *
215+
"inputs ($n)",
210216
),
211217
)
212218
end
213219
end
220+
221+
function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector)
222+
return validate_inplace_dims(K, x, x)
223+
end

0 commit comments

Comments
 (0)