Skip to content

Simplify implementation of TensorProduct #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ Base.length(k::KernelProduct) = length(k.kernels)

(κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)

hadamard(x, y) = x .* y

function kernelmatrix(κ::KernelProduct, x::AbstractVector)
return reduce(hadamard, kernelmatrix(κ.kernels[i], x) for i in 1:length(κ))
end
Expand Down
38 changes: 11 additions & 27 deletions src/kernels/tensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ function (kernel::TensorProduct)(x, y)
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

# TODO: General implementation of `kernelmatrix` and `kerneldiagmatrix`
# Default implementation assumes 1D observations

function validate_domain(k::TensorProduct, x::AbstractVector)
dim(x) == length(k) ||
error("number of kernels and groups of features are not consistent")
Expand Down Expand Up @@ -70,25 +67,6 @@ function kernelmatrix!(
return K
end

# mapreduce with multiple iterators requires Julia 1.2 or later.

function kernelmatrix(k::TensorProduct, x::AbstractVector)
validate_domain(k, x)

return mapreduce((x, y) -> x .* y, zip(k.kernels, slices(x))) do (k, xi)
kernelmatrix(k, xi)
end
end

function kernelmatrix(k::TensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, xi, yi)
kernelmatrix(k, xi, yi)
end
end

function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)
Expand All @@ -102,13 +80,19 @@ function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVecto
return K
end

function kerneldiagmatrix(k::TensorProduct, x::AbstractVector)
function kernelmatrix(k::TensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
end

kernels_and_inputs = zip(k.kernels, slices(x))
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, xi)
kerneldiagmatrix(k, xi)
end
function kernelmatrix(k::TensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
end

function kerneldiagmatrix(k::TensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
end

Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0)
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
hadamard(x, y) = x .* y

# Macro for checking arguments
macro check_args(K, param, cond, desc=string(cond))
quote
Expand Down