Skip to content

Commit bb3e859

Browse files
authored
Merge pull request #81 from devmotion/tensor
2 parents 18106f5 + e8ed332 commit bb3e859

File tree

5 files changed

+386
-6
lines changed

5 files changed

+386
-6
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export MahalanobisKernel, GaborKernel, PiecewisePolynomialKernel
1919
export PeriodicKernel
2020
export KernelSum, KernelProduct
2121
export TransformedKernel, ScaledKernel
22+
export TensorProduct
2223

2324
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
2425

@@ -57,6 +58,7 @@ include("kernels/scaledkernel.jl")
5758
include("matrix/kernelmatrix.jl")
5859
include("kernels/kernelsum.jl")
5960
include("kernels/kernelproduct.jl")
61+
include("kernels/tensorproduct.jl")
6062
include("approximations/nystrom.jl")
6163
include("generic.jl")
6264

src/kernels/tensorproduct.jl

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""
2+
TensorProduct(kernels...)
3+
4+
Create a tensor product kernel from kernels ``k_1, \\ldots, k_n``, i.e.,
5+
a kernel ``k`` that is given by
6+
```math
7+
k(x, y) = \\prod_{i=1}^n k_i(x_i, y_i).
8+
```
9+
10+
The `kernels` can be specified as individual arguments, a tuple, or an iterable data
11+
structure such as an array. Using a tuple or individual arguments guarantees that
12+
`TensorProduct` is concretely typed but might lead to large compilation times if the
13+
number of kernels is large.
14+
"""
15+
struct TensorProduct{K} <: Kernel
16+
kernels::K
17+
end
18+
19+
function TensorProduct(kernel::Kernel, kernels::Kernel...)
20+
return TensorProduct((kernel, kernels...))
21+
end
22+
23+
Base.length(kernel::TensorProduct) = length(kernel.kernels)
24+
25+
(kernel::TensorProduct)(x, y) = kappa(kernel, x, y)
26+
function kappa(kernel::TensorProduct, x, y)
27+
return prod(kappa(k, xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
28+
end
29+
30+
# TODO: General implementation of `kernelmatrix` and `kerneldiagmatrix`
31+
# Default implementation assumes 1D observations
32+
33+
function kernelmatrix!(
34+
K::AbstractMatrix,
35+
kernel::TensorProduct,
36+
X::AbstractMatrix;
37+
obsdim::Int = defaultobs,
38+
)
39+
obsdim (1, 2) || "obsdim should be 1 or 2 (see docs of kernelmatrix))"
40+
41+
featuredim = feature_dim(obsdim)
42+
if !check_dims(K, X, X, featuredim, obsdim)
43+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
44+
"consistent with X $(size(X))"))
45+
end
46+
47+
size(X, featuredim) == length(kernel) ||
48+
error("number of kernels and groups of features are not consistent")
49+
50+
kernels_and_inputs = zip(kernel.kernels, eachslice(X; dims = featuredim))
51+
kernelmatrix!(K, first(kernels_and_inputs)...)
52+
for (k, Xi) in Iterators.drop(kernels_and_inputs, 1)
53+
K .*= kernelmatrix(k, Xi)
54+
end
55+
56+
return K
57+
end
58+
59+
function kernelmatrix!(
60+
K::AbstractMatrix,
61+
kernel::TensorProduct,
62+
X::AbstractMatrix,
63+
Y::AbstractMatrix;
64+
obsdim::Int = defaultobs,
65+
)
66+
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
67+
68+
featuredim = feature_dim(obsdim)
69+
if !check_dims(K, X, Y, featuredim, obsdim)
70+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
71+
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
72+
end
73+
74+
size(X, featuredim) == length(kernel) ||
75+
error("number of kernels and groups of features are not consistent")
76+
77+
kernels_and_inputs = zip(
78+
kernel.kernels,
79+
eachslice(X; dims = featuredim),
80+
eachslice(Y; dims = featuredim),
81+
)
82+
kernelmatrix!(K, first(kernels_and_inputs)...)
83+
for (k, Xi, Yi) in Iterators.drop(kernels_and_inputs, 1)
84+
K .*= kernelmatrix(k, Xi, Yi)
85+
end
86+
87+
return K
88+
end
89+
90+
# mapreduce with multiple iterators requires Julia 1.2 or later.
91+
92+
function kernelmatrix(
93+
kernel::TensorProduct,
94+
X::AbstractMatrix;
95+
obsdim::Int = defaultobs,
96+
)
97+
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
98+
99+
featuredim = feature_dim(obsdim)
100+
if !check_dims(X, X, featuredim, obsdim)
101+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
102+
"consistent with X $(size(X))"))
103+
end
104+
105+
size(X, featuredim) == length(kernel) ||
106+
error("number of kernels and groups of features are not consistent")
107+
108+
return mapreduce((x, y) -> x .* y,
109+
zip(kernel.kernels, eachslice(X; dims = featuredim))) do (k, Xi)
110+
kernelmatrix(k, Xi)
111+
end
112+
end
113+
114+
function kernelmatrix(
115+
kernel::TensorProduct,
116+
X::AbstractMatrix,
117+
Y::AbstractMatrix;
118+
obsdim::Int = defaultobs
119+
)
120+
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
121+
122+
featuredim = feature_dim(obsdim)
123+
if !check_dims(X, Y, featuredim, obsdim)
124+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
125+
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
126+
end
127+
128+
size(X, featuredim) == length(kernel) ||
129+
error("number of kernels and groups of features are not consistent")
130+
131+
kernels_and_inputs = zip(
132+
kernel.kernels,
133+
eachslice(X; dims = featuredim),
134+
eachslice(Y; dims = featuredim),
135+
)
136+
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, Xi, Yi)
137+
kernelmatrix(k, Xi, Yi)
138+
end
139+
end
140+
141+
function kerneldiagmatrix!(
142+
K::AbstractVector,
143+
kernel::TensorProduct,
144+
X::AbstractMatrix;
145+
obsdim::Int = defaultobs
146+
)
147+
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
148+
if length(K) != size(X, obsdim)
149+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
150+
"consistent with X $(size(X))"))
151+
end
152+
153+
featuredim = feature_dim(obsdim)
154+
size(X, featuredim) == length(kernel) ||
155+
error("number of kernels and groups of features are not consistent")
156+
157+
kernels_and_inputs = zip(kernel.kernels, eachslice(X; dims = featuredim))
158+
kerneldiagmatrix!(K, first(kernels_and_inputs)...)
159+
for (k, Xi) in Iterators.drop(kernels_and_inputs, 1)
160+
K .*= kerneldiagmatrix(k, Xi)
161+
end
162+
163+
return K
164+
end
165+
166+
function kerneldiagmatrix(
167+
kernel::TensorProduct,
168+
X::AbstractMatrix;
169+
obsdim::Int = defaultobs
170+
)
171+
obsdim (1,2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
172+
173+
featuredim = feature_dim(obsdim)
174+
size(X, featuredim) == length(kernel) ||
175+
error("number of kernels and groups of features are not consistent")
176+
177+
kernels_and_inputs = zip(kernel.kernels, eachslice(X; dims = featuredim))
178+
return mapreduce((x, y) -> x .* y, kernels_and_inputs) do (k, Xi)
179+
kerneldiagmatrix(k, Xi)
180+
end
181+
end
182+
183+
Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0)
184+
185+
function printshifted(io::IO, kernel::TensorProduct, shift::Int)
186+
print(io, "Tensor product of ", length(kernel), " kernels:")
187+
for k in kernel.kernels
188+
print(io, "\n")
189+
for _ in 1:(shift + 1)
190+
print(io, "\t")
191+
end
192+
print(io, "- ")
193+
printshifted(io, k, shift + 2)
194+
end
195+
end

src/matrix/kernelmatrix.jl

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
"""
2-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer = 2)
3-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer = 2)
2+
kernelmatrix!(K::AbstractMatrix, κ::Kernel, X; obsdim::Integer = 2)
3+
kernelmatrix!(K::AbstractMatrix, κ::Kernel, X, Y; obsdim::Integer = 2)
44
55
In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will be overwritten with the kernel matrix.
66
"""
77
kernelmatrix!
88

9+
function kernelmatrix!(
10+
K::AbstractMatrix,
11+
kernel::Kernel,
12+
X::AbstractVector;
13+
obsdim::Int = defaultobs,
14+
)
15+
return kernelmatrix!(K, kernel, reshape(X, 1, :); obsdim = 2)
16+
end
917

1018
function kernelmatrix!(
1119
K::AbstractMatrix,
@@ -23,6 +31,16 @@ end
2331
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
2432
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
2533

34+
function kernelmatrix!(
35+
K::AbstractMatrix,
36+
kernel::Kernel,
37+
X::AbstractVector,
38+
Y::AbstractVector;
39+
obsdim::Int = defaultobs,
40+
)
41+
return kernelmatrix!(K, kernel, reshape(X, 1, :), reshape(Y, 1, :); obsdim = 2)
42+
end
43+
2644
function kernelmatrix!(
2745
K::AbstractMatrix,
2846
κ::Kernel,
@@ -60,8 +78,8 @@ _kernel(κ::TransformedKernel, x::AbstractVector, y::AbstractVector; obsdim::Int
6078
_kernel(kernel(κ), apply.transform, x), apply.transform, y), obsdim = obsdim)
6179

6280
"""
63-
kernelmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
64-
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int = 2)
81+
kernelmatrix(κ::Kernel, X; obsdim::Int = 2)
82+
kernelmatrix(κ::Kernel, X, Y; obsdim::Int = 2)
6583
6684
Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
6785
`obsdim = 1` means the matrix `X` (and `Y`) has size #samples x #dimension
@@ -88,6 +106,15 @@ end
88106
kernelmatrix::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
89107
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
90108

109+
function kernelmatrix(
110+
kernel::Kernel,
111+
X::AbstractVector{<:Real},
112+
Y::AbstractVector{<:Real};
113+
obsdim::Int = defaultobs,
114+
)
115+
return kernelmatrix(kernel, reshape(X, 1, :), reshape(Y, 1, :); obsdim = 2)
116+
end
117+
91118
function kernelmatrix(
92119
κ::Kernel,
93120
X::AbstractMatrix,
@@ -107,12 +134,20 @@ kernelmatrix(κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim
107134
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)
108135

109136
"""
110-
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
137+
kerneldiagmatrix(κ::Kernel, X; obsdim::Int = 2)
111138
112139
Calculate the diagonal matrix of `X` with respect to kernel `κ`
113140
`obsdim = 1` means the matrix `X` has size #samples x #dimension
114141
`obsdim = 2` means the matrix `X` has size #dimension x #samples
115142
"""
143+
function kerneldiagmatrix(
144+
kernel::Kernel,
145+
X::AbstractVector;
146+
obsdim::Int = defaultobs,
147+
)
148+
return kerneldiagmatrix(kernel, reshape(X, 1, :); obsdim = 2)
149+
end
150+
116151
function kerneldiagmatrix(
117152
κ::Kernel,
118153
X::AbstractMatrix;
@@ -127,10 +162,19 @@ function kerneldiagmatrix(
127162
end
128163

129164
"""
130-
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int = 2)
165+
kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X; obsdim::Int = 2)
131166
132167
In place version of [`kerneldiagmatrix`](@ref)
133168
"""
169+
function kerneldiagmatrix!(
170+
K::AbstractVector,
171+
kernel::Kernel,
172+
X::AbstractVector;
173+
obsdim::Int = defaultobs,
174+
)
175+
return kerneldiagmatrix!(K, kernel, reshape(X, 1, :); obsdim = 2)
176+
end
177+
134178
function kerneldiagmatrix!(
135179
K::AbstractVector,
136180
κ::Kernel,

0 commit comments

Comments
 (0)