Skip to content

Commit 87eeae7

Browse files
committed
Used vec_of_vecs for ambiguities
1 parent 1698e51 commit 87eeae7

File tree

1 file changed

+8
-34
lines changed

1 file changed

+8
-34
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@ function kernelmatrix!(
2626
obsdim::Int = defaultobs
2727
)
2828
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
29-
if obsdim == 1
30-
kernelmatrix!(K, κ, ColVecs(X'))
31-
else
32-
kernelmatrix!(K, κ, ColVecs(X))
33-
end
29+
kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
3430
end
3531

3632
function kernelmatrix!(
@@ -66,11 +62,8 @@ function kernelmatrix!(
6662
obsdim::Int = defaultobs
6763
)
6864
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
69-
if obsdim == 1
70-
kernelmatrix!(K, κ, ColVecs(X'), ColVecs(Y'))
71-
else
72-
kernelmatrix!(K, κ, ColVecs(X), ColVecs(Y))
73-
end
65+
kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
66+
7467
end
7568

7669
function kernelmatrix!(
@@ -110,11 +103,7 @@ end
110103

111104
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
112105
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
113-
if obsdim == 1
114-
kernelmatrix(κ, ColVecs(X'))
115-
else
116-
kernelmatrix(κ, ColVecs(X))
117-
end
106+
kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
118107
end
119108

120109
function kernelmatrix(
@@ -127,21 +116,14 @@ function kernelmatrix(
127116
if !check_dims(X, Y, feature_dim(obsdim))
128117
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
129118
end
130-
_kernelmatrix(κ, X, Y, obsdim)
119+
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
131120
end
132121

133122
function kernelmatrix::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs)
134123
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
135-
if obsdim == 1
136-
kernelmatrix(κ, ColVecs(X'), ColVecs(Y'))
137-
else
138-
kernelmatrix(κ, ColVecs(X), ColVecs(Y))
139-
end
124+
kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
140125
end
141126

142-
@inline _kernelmatrix::SimpleKernel, X, Y, obsdim) =
143-
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
144-
145127
"""
146128
kerneldiagmatrix(κ::Kernel, X; obsdim::Int = 2)
147129
@@ -157,11 +139,7 @@ function kerneldiagmatrix(
157139
obsdim::Int = defaultobs
158140
)
159141
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
160-
if obsdim == 1
161-
kerneldiagmatrix(κ, ColVecs(X'))
162-
else
163-
kerneldiagmatrix(κ, ColVecs(X))
164-
end
142+
kerneldiagmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
165143
end
166144

167145
function kerneldiagmatrix::Kernel, X::AbstractVector)
@@ -183,11 +161,7 @@ function kerneldiagmatrix!(
183161
if length(K) != size(X,obsdim)
184162
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
185163
end
186-
if obsdim == 1
187-
kerneldiagmatrix!(K, κ, ColVecs(X'))
188-
else
189-
kerneldiagmatrix!(K, κ, ColVecs(X))
190-
end
164+
kerneldiagmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
191165
return K
192166
end
193167

0 commit comments

Comments
 (0)