Skip to content

Commit 51ab386

Browse files
committed
Simplify implementation of TensorProduct
1 parent 197a4d2 commit 51ab386

File tree

3 files changed

+15
-28
lines changed

3 files changed

+15
-28
lines changed

src/kernels/kernelproduct.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ Base.length(k::KernelProduct) = length(k.kernels)
2424

2525
::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
2626

27-
hadamard(x, y) = x .* y
28-
2927
function kernelmatrix::KernelProduct, x::AbstractVector)
3028
return reduce(hadamard, kernelmatrix.kernels[i], x) for i in 1:length(κ))
3129
end

src/kernels/tensorproduct.jl

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ function (kernel::TensorProduct)(x, y)
2626
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
2727
end
2828

29-
# TODO: General implementation of `kernelmatrix` and `kerneldiagmatrix`
30-
# Default implementation assumes 1D observations
31-
3229
function validate_domain(k::TensorProduct, x::AbstractVector)
3330
dim(x) == length(k) ||
3431
error("number of kernels and groups of features are not consistent")
@@ -70,25 +67,6 @@ function kernelmatrix!(
7067
return K
7168
end
7269

73-
# mapreduce with multiple iterators requires Julia 1.2 or later.
74-
75-
function kernelmatrix(k::TensorProduct, x::AbstractVector)
76-
validate_domain(k, x)
77-
78-
return mapreduce((x, y) -> x .* y, zip(k.kernels, slices(x))) do (k, xi)
79-
kernelmatrix(k, xi)
80-
end
81-
end
82-
83-
function kernelmatrix(k::TensorProduct, x::AbstractVector, y::AbstractVector)
84-
validate_domain(k, x)
85-
86-
kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
87-
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, xi, yi)
88-
kernelmatrix(k, xi, yi)
89-
end
90-
end
91-
9270
function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVector)
9371
validate_inplace_dims(K, x)
9472
validate_domain(k, x)
@@ -102,13 +80,22 @@ function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVecto
10280
return K
10381
end
10482

83+
function kernelmatrix(k::TensorProduct, x::AbstractVector)
84+
validate_domain(k, x)
85+
86+
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
87+
end
88+
89+
function kernelmatrix(k::TensorProduct, x::AbstractVector, y::AbstractVector)
90+
validate_domain(k, x)
91+
92+
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
93+
end
94+
10595
function kerneldiagmatrix(k::TensorProduct, x::AbstractVector)
10696
validate_domain(k, x)
10797

108-
kernels_and_inputs = zip(k.kernels, slices(x))
109-
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, xi)
110-
kerneldiagmatrix(k, xi)
111-
end
98+
return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
11299
end
113100

114101
Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0)

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
hadamard(x, y) = x .* y
2+
13
# Macro for checking arguments
24
macro check_args(K, param, cond, desc=string(cond))
35
quote

0 commit comments

Comments
 (0)