Skip to content

Commit 70e4691

Browse files
committed
Some fixes
1 parent 0c79837 commit 70e4691

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

src/kernels/tensorproduct.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,16 @@ function kernelmatrix!(
4040

4141
featuredim = feature_dim(obsdim)
4242
if !check_dims(K, X, X, featuredim, obsdim)
43-
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
43+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
44+
"consistent with X $(size(X))"))
4445
end
4546

4647
size(X, featuredim) == length(kernel) ||
4748
error("number of kernels and groups of features are not consistent")
4849

49-
kernelmatrix!(K, kernel.kernels[1], selectdim(X, featuredim, 1))
50-
for (k, Xi) in Iterators.drop(zip(kernel.kernels, eachslice(X; dims = featuredim)), 1)
50+
kernels_and_input = zip(kernel.kernels, eachslice(X; dims = featuredim))
51+
kernelmatrix!(K, first(kernels_and_input)...)
52+
for (k, Xi) in Iterators.drop(kernels_and_input, 1)
5153
K .*= kernelmatrix(k, Xi)
5254
end
5355

@@ -65,14 +67,15 @@ function kernelmatrix!(
6567

6668
featuredim = feature_dim(obsdim)
6769
if !check_dims(K, X, Y, featuredim, obsdim)
68-
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
70+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
71+
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
6972
end
7073

7174
size(X, featuredim) == length(kernel) ||
7275
error("number of kernels and groups of features are not consistent")
7376

7477
kernels_and_input = zip(
75-
zip(kernel.kernels,
78+
kernel.kernels,
7679
eachslice(X; dims = featuredim),
7780
eachslice(Y; dims = featuredim),
7881
)
@@ -95,7 +98,8 @@ function kernelmatrix(
9598

9699
featuredim = feature_dim(obsdim)
97100
if !check_dims(X, X, featuredim, obsdim)
98-
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
101+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
102+
"consistent with X $(size(X))"))
99103
end
100104

101105
size(X, featuredim) == length(kernel) ||
@@ -113,11 +117,12 @@ function kernelmatrix(
113117
Y::AbstractMatrix;
114118
obsdim::Int = defaultobs
115119
)
116-
@assert obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
120+
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
117121

118122
featuredim = feature_dim(obsdim)
119123
if !check_dims(X, Y, featuredim, obsdim)
120-
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
124+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
125+
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
121126
end
122127

123128
size(X, featuredim) == length(kernel) ||
@@ -141,15 +146,17 @@ function kerneldiagmatrix!(
141146
)
142147
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
143148
if length(K) != size(X, obsdim)
144-
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
149+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
150+
"consistent with X $(size(X))"))
145151
end
146152

147153
featuredim = feature_dim(obsdim)
148154
size(X, featuredim) == length(kernel) ||
149155
error("number of kernels and groups of features are not consistent")
150156

151-
kerneldiagmatrix!(K, kernel.kernels[1], selectdim(X, featuredim, 1))
152-
for (k, Xi) in Iterators.drop(zip(kernel.kernels, eachslice(X; dims = featuredim)), 1)
157+
kernels_and_input = zip(kernel.kernels, eachslice(X; dims = featuredim))
158+
kerneldiagmatrix!(K, first(kernels_and_input)...)
159+
for (k, Xi) in Iterators.drop(kernels_and_input, 1)
153160
K .*= kerneldiagmatrix(k, Xi)
154161
end
155162

0 commit comments

Comments
 (0)