Skip to content

Commit 87be861

Browse files
committed
Applied formatting changes
1 parent f1d0433 commit 87be861

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
kernelmatrix!(K::AbstractMatrix, κ::Kernel, X, Y; obsdim::Integer = 2)
44
55
In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will be overwritten with the kernel matrix.
6+
Will return the computed matrix `K`
67
"""
78
kernelmatrix!
89

@@ -16,7 +17,7 @@ function kernelmatrix!(
1617
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
1718
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
1819
end
19-
map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, dims = obsdim))
20+
return map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, dims = obsdim))
2021
end
2122

2223
function kernelmatrix!(
@@ -25,7 +26,7 @@ function kernelmatrix!(
2526
X::AbstractMatrix;
2627
obsdim::Int = defaultobs
2728
)
28-
kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
29+
return kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
2930
end
3031

3132
function kernelmatrix!(
@@ -37,6 +38,7 @@ function kernelmatrix!(
3738
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
3839
end
3940
K .= κ.(X, X')
41+
return K
4042
end
4143

4244
function kernelmatrix!(
@@ -50,7 +52,7 @@ function kernelmatrix!(
5052
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
5153
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
5254
end
53-
map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, Y, dims = obsdim))
55+
return map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, Y, dims = obsdim))
5456
end
5557

5658
function kernelmatrix!(
@@ -60,8 +62,7 @@ function kernelmatrix!(
6062
Y::AbstractMatrix;
6163
obsdim::Int = defaultobs
6264
)
63-
kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
64-
65+
return kernelmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
6566
end
6667

6768
function kernelmatrix!(
@@ -74,6 +75,7 @@ function kernelmatrix!(
7475
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X)) and Y $(size(Y))"))
7576
end
7677
K .= κ.(X, Y')
78+
return K
7779
end
7880

7981
"""
@@ -86,21 +88,17 @@ Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
8688
"""
8789
kernelmatrix
8890

89-
function kernelmatrix::Kernel, X::AbstractVector)
90-
kernelmatrix(κ, X, X) #TODO Can be optimized later
91-
end
91+
kernelmatrix::Kernel, X::AbstractVector) = kernelmatrix(κ, X, X)
9292

93-
function kernelmatrix::Kernel, X::AbstractVector, Y::AbstractVector)
94-
κ.(X, Y')
95-
end
93+
kernelmatrix::Kernel, X::AbstractVector, Y::AbstractVector) = κ.(X, Y')
9694

9795
function kernelmatrix::SimpleKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
9896
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
99-
K = map(x -> kappa(κ, x), pairwise(metric(κ), X, dims = obsdim))
97+
return map(x -> kappa(κ, x), pairwise(metric(κ), X, dims = obsdim))
10098
end
10199

102100
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
103-
kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
101+
return kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
104102
end
105103

106104
function kernelmatrix(
@@ -113,11 +111,11 @@ function kernelmatrix(
113111
if !check_dims(X, Y, feature_dim(obsdim))
114112
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
115113
end
116-
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
114+
return map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
117115
end
118116

119117
function kernelmatrix::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs)
120-
kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
118+
return kernelmatrix(κ, vec_of_vecs(X, obsdim = obsdim), vec_of_vecs(Y, obsdim = obsdim))
121119
end
122120

123121
"""
@@ -134,12 +132,10 @@ function kerneldiagmatrix(
134132
X::AbstractMatrix;
135133
obsdim::Int = defaultobs
136134
)
137-
kerneldiagmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
135+
return kerneldiagmatrix(κ, vec_of_vecs(X, obsdim = obsdim))
138136
end
139137

140-
function kerneldiagmatrix::Kernel, X::AbstractVector)
141-
κ.(X, X)
142-
end
138+
kerneldiagmatrix::Kernel, X::AbstractVector) = κ.(X, X)
143139

144140
"""
145141
kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X; obsdim::Int = 2)
@@ -151,21 +147,20 @@ function kerneldiagmatrix!(
151147
κ::Kernel,
152148
X::AbstractMatrix;
153149
obsdim::Int = defaultobs
154-
)
150+
)
155151
if length(K) != size(X,obsdim)
156152
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
157153
end
158-
kerneldiagmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
159-
return K
154+
return kerneldiagmatrix!(K, κ, vec_of_vecs(X, obsdim = obsdim))
160155
end
161156

162157
function kerneldiagmatrix!(
163158
K::AbstractVector,
164159
κ::Kernel,
165160
X::AbstractVector
166-
)
161+
)
167162
if length(K) != length(X)
168163
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(length(X))"))
169164
end
170-
map!(κ, K, X, X)
165+
return map!(κ, K, X, X)
171166
end

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :))
6666
# end
6767

6868
function check_dims(K, X::AbstractVector, Y::AbstractVector)
69-
size(K) == (length(X), length(Y))
69+
return size(K) == (length(X), length(Y))
7070
end
7171

7272

7373
## Won't be needed with full ColVecs implementation
7474
function check_dims(K, X::AbstractMatrix, Y::AbstractMatrix, featdim, obsdim)
75-
check_dims(X, Y, featdim) &&
76-
(size(K) == (size(X, obsdim), size(Y, obsdim)))
75+
return check_dims(X, Y, featdim) &&
76+
(size(K) == (size(X, obsdim), size(Y, obsdim)))
7777
end
7878

7979
check_dims(X::AbstractMatrix, Y::AbstractMatrix, featdim) = size(X, featdim) == size(Y, featdim)

0 commit comments

Comments
 (0)