-
Notifications
You must be signed in to change notification settings - Fork 36
Add general TensorProduct kernel #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
822e29e
8e50390
e677fa1
0c79837
70e4691
ba4f5ca
e8ed332
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
""" | ||
TensorProduct(kernels...) | ||
|
||
Create a tensor product of kernels. | ||
""" | ||
struct TensorProduct{K} <: Kernel | ||
kernels::K | ||
end | ||
|
||
function TensorProduct(kernel::Kernel, kernels::Kernel...) | ||
return TensorProduct((kernel, kernels...)) | ||
end | ||
|
||
Base.length(kernel::TensorProduct) = length(kernel.kernels) | ||
|
||
(kernel::TensorProduct)(x, y) = kappa(kernel, x, y) | ||
function kappa(kernel::TensorProduct, x, y) | ||
return prod(kappa(k, xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y)) | ||
end | ||
|
||
# TODO: General implementation of `kernelmatrix` and `kerneldiagmatrix` | ||
# Default implementation assumes 1D observations | ||
|
||
function kernelmatrix!( | ||
K::AbstractMatrix, | ||
kernel::TensorProduct, | ||
X::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
obsdim ∈ (1, 2) || "obsdim should be 1 or 2 (see docs of kernelmatrix))" | ||
|
||
featuredim = feature_dim(obsdim) | ||
if !check_dims(K, X, X, featuredim, obsdim) | ||
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))")) | ||
end | ||
|
||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
kernelmatrix!(K, kernel.kernels[1], selectdim(X, featuredim, 1)) | ||
for (k, Xi) in Iterators.drop(zip(kernel.kernels, eachslice(X; dims = featuredim)), 1) | ||
K .*= kernelmatrix(k, Xi) | ||
end | ||
|
||
return K | ||
end | ||
|
||
function kernelmatrix!( | ||
K::AbstractMatrix, | ||
kernel::TensorProduct, | ||
X::AbstractMatrix, | ||
Y::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))") | ||
|
||
featuredim = feature_dim(obsdim) | ||
if !check_dims(K, X, Y, featuredim, obsdim) | ||
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))")) | ||
end | ||
|
||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
kernelmatrix!(K, kernel.kernels[1], selectdim(X, featuredim, 1), | ||
selectdim(Y, featuredim, 1)) | ||
for (k, Xi, Yi) in Iterators.drop(zip(kernel.kernels, | ||
eachslice(X; dims = featuredim), | ||
eachslice(Y; dims = featuredim)), 1) | ||
K .*= kernelmatrix(k, Xi, Yi) | ||
end | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return K | ||
end | ||
|
||
# mapreduce with multiple iterators requires Julia 1.2 or later. | ||
|
||
function kernelmatrix( | ||
kernel::TensorProduct, | ||
X::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))") | ||
|
||
featuredim = feature_dim(obsdim) | ||
if !check_dims(X, X, featuredim, obsdim) | ||
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not within 92 char lim. Please wrap string over multiple lines There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I just copied this from |
||
end | ||
|
||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
return mapreduce((x, y) -> x .* y, | ||
zip(kernel.kernels, eachslice(X; dims = featuredim))) do (k, Xi) | ||
kernelmatrix(k, Xi) | ||
end | ||
end | ||
|
||
function kernelmatrix( | ||
kernel::TensorProduct, | ||
X::AbstractMatrix, | ||
Y::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
) | ||
@assert obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))") | ||
|
||
featuredim = feature_dim(obsdim) | ||
if !check_dims(X, Y, featuredim, obsdim) | ||
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))")) | ||
end | ||
|
||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
return mapreduce((x, y) -> x .* y, | ||
zip(kernel.kernels, | ||
eachslice(X; dims = featuredim), | ||
eachslice(Y; dims = featuredim))) do (k, Xi, Yi) | ||
kernelmatrix(k, Xi, Yi) | ||
end | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
function kerneldiagmatrix!( | ||
K::AbstractVector, | ||
kernel::TensorProduct, | ||
X::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
) | ||
obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))") | ||
if length(K) != size(X, obsdim) | ||
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line length |
||
end | ||
|
||
featuredim = feature_dim(obsdim) | ||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
kerneldiagmatrix!(K, kernel.kernels[1], selectdim(X, featuredim, 1)) | ||
for (k, Xi) in Iterators.drop(zip(kernel.kernels, eachslice(X; dims = featuredim)), 1) | ||
K .*= kerneldiagmatrix(k, Xi) | ||
end | ||
|
||
return K | ||
end | ||
|
||
function kerneldiagmatrix( | ||
kernel::TensorProduct, | ||
X::AbstractMatrix; | ||
obsdim::Int = defaultobs | ||
) | ||
obsdim ∈ (1,2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))") | ||
|
||
featuredim = feature_dim(obsdim) | ||
size(X, featuredim) == length(kernel) || | ||
error("number of kernels and groups of features are not consistent") | ||
|
||
return mapreduce((x, y) -> x .* y, | ||
zip(kernel.kernels, eachslice(X; dims = featuredim))) do (k, Xi) | ||
kerneldiagmatrix(k, Xi) | ||
end | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0) | ||
|
||
function printshifted(io::IO, kernel::TensorProduct, shift::Int) | ||
print(io, "Tensor product of ", length(kernel), " kernels:") | ||
for k in kernel.kernels | ||
print(io, "\n") | ||
for _ in 1:(shift + 1) | ||
print(io, "\t") | ||
end | ||
print(io, "- ") | ||
printshifted(io, k, shift + 2) | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
@testset "tensorproduct" begin | ||
rng = MersenneTwister(123456) | ||
u1 = rand(rng, 10) | ||
u2 = rand(rng, 10) | ||
v1 = rand(rng, 5) | ||
v2 = rand(rng, 5) | ||
|
||
# kernels | ||
k1 = SqExponentialKernel() | ||
k2 = ExponentialKernel() | ||
kernel1 = TensorProduct(k1, k2) | ||
kernel2 = TensorProduct([k1, k2]) | ||
|
||
@test kernel1.kernels === (k1, k2) === TensorProduct((k1, k2)).kernels | ||
|
||
@testset "kappa" begin | ||
for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2])) | ||
val = k1(x[1], y[1]) * k2(x[2], y[2]) | ||
|
||
@test kernel1(x, y) == kernel2(x, y) == val | ||
@test KernelFunctions.kappa(kernel1, x, y) == | ||
KernelFunctions.kappa(kernel2, x, y) == val | ||
end | ||
end | ||
|
||
@testset "kernelmatrix" begin | ||
X = rand(2, 10) | ||
Y = rand(2, 10) | ||
trueX = kernelmatrix(k1, X[1, :]) .* kernelmatrix(k2, X[2, :]) | ||
trueXY = kernelmatrix(k1, X[1, :], Y[1, :]) .* kernelmatrix(k2, X[2, :], Y[2, :]) | ||
tmp = Matrix{Float64}(undef, 10, 10) | ||
|
||
for kernel in (kernel1, kernel2) | ||
@test kernelmatrix(kernel, X) == trueX | ||
@test kernelmatrix(kernel, X'; obsdim = 1) == trueX | ||
|
||
@test kernelmatrix(kernel, X, Y) == trueXY | ||
@test kernelmatrix(kernel, X', Y'; obsdim = 1) == trueXY | ||
|
||
fill!(tmp, 0) | ||
kernelmatrix!(tmp, kernel, X) | ||
@test tmp == trueX | ||
|
||
fill!(tmp, 0) | ||
kernelmatrix!(tmp, kernel, X'; obsdim = 1) | ||
@test tmp == trueX | ||
|
||
fill!(tmp, 0) | ||
kernelmatrix!(tmp, kernel, X, Y) | ||
@test tmp == trueXY | ||
|
||
fill!(tmp, 0) | ||
kernelmatrix!(tmp, kernel, X', Y'; obsdim = 1) | ||
@test tmp == trueXY | ||
end | ||
end | ||
|
||
@testset "kerneldiagmatrix" begin | ||
X = rand(2, 10) | ||
trueval = ones(10) | ||
tmp = Vector{Float64}(undef, 10) | ||
|
||
for kernel in (kernel1, kernel2) | ||
@test kerneldiagmatrix(kernel, X) == trueval | ||
@test kerneldiagmatrix(kernel, X'; obsdim = 1) == trueval | ||
|
||
fill!(tmp, 0) | ||
kerneldiagmatrix!(tmp, kernel, X) | ||
@test tmp == trueval | ||
|
||
fill!(tmp, 0) | ||
kerneldiagmatrix!(tmp, kernel, X'; obsdim = 1) | ||
@test tmp == trueval | ||
end | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.