Skip to content

Commit d6b88a1

Browse files
committed
Removed the abstractvector{<:real} wrapper and adapted the rest of the functions
1 parent 5ddf568 commit d6b88a1

File tree

4 files changed

+12
-47
lines changed

4 files changed

+12
-47
lines changed

src/kernels/tensorproduct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ function kernelmatrix(
9696
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
9797

9898
featuredim = feature_dim(obsdim)
99-
if !check_dims(X, X, featuredim, obsdim)
99+
if !check_dims(X, X, featuredim)
100100
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
101101
"consistent with X $(size(X))"))
102102
end
@@ -119,7 +119,7 @@ function kernelmatrix(
119119
obsdim (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
120120

121121
featuredim = feature_dim(obsdim)
122-
if !check_dims(X, Y, featuredim, obsdim)
122+
if !check_dims(X, Y, featuredim)
123123
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
124124
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
125125
end

src/matrix/kernelkroneckermat.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function kernelkronmat(
88
dims::Int
99
)
1010
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see `iskroncompatible()`)"
11-
k = kernelmatrix(κ,reshape(X,:,1),obsdim=1)
12-
kronecker(k,dims)
11+
k = kernelmatrix(κ, X)
12+
kronecker(k, dims)
1313
end
1414

1515
function kernelkronmat(
@@ -18,8 +18,8 @@ function kernelkronmat(
1818
obsdim::Int=defaultobs
1919
)
2020
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices"
21-
Ks = kernelmatrix.(κ,X,obsdim=obsdim)
22-
K = reduce(,Ks)
21+
Ks = kernelmatrix.(κ, X)
22+
K = reduce(, Ks)
2323
end
2424

2525

src/matrix/kernelmatrix.jl

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@ In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will b
66
"""
77
kernelmatrix!
88

9-
## Wrapper for vector of reals
10-
function kernelmatrix!(
11-
K::AbstractMatrix,
12-
κ::Kernel,
13-
X::AbstractVector{<:Real}
14-
)
15-
kernelmatrix!(K, κ, ColVecs(reshape(X, 1, :)))
16-
end
17-
189
function kernelmatrix!(
1910
K::AbstractMatrix,
2011
κ::SimpleKernel,
@@ -53,16 +44,6 @@ function kernelmatrix!(
5344
K .= κ.(X, X')
5445
end
5546

56-
## Wrapper for vector of reals
57-
function kernelmatrix!(
58-
K::AbstractMatrix,
59-
κ::Kernel,
60-
X::AbstractVector{<:Real},
61-
Y::AbstractVector{<:Real}
62-
)
63-
kernelmatrix!(K, κ, ColVecs(reshape(X, 1, :)), ColVecs(reshape(Y, 1, :)))
64-
end
65-
6647
function kernelmatrix!(
6748
K::AbstractMatrix,
6849
κ::SimpleKernel,
@@ -114,14 +95,6 @@ Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
11495
"""
11596
kernelmatrix
11697

117-
function kernelmatrix(
118-
κ::Kernel,
119-
X::AbstractVector{<:Real};
120-
obsdim::Int = defaultobs,
121-
)
122-
kernelmatrix(κ, reshape(X, 1, :), obsdim = 2)
123-
end
124-
12598
function kernelmatrix::Kernel, X::AbstractVector)
12699
kernelmatrix(κ, X, X) #TODO Can be optimized later
127100
end
@@ -144,22 +117,14 @@ function kernelmatrix(κ::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
144117
end
145118
end
146119

147-
function kernelmatrix(
148-
κ::Kernel,
149-
X::AbstractVector{<:Real},
150-
Y::AbstractVector{<:Real}
151-
)
152-
kernelmatrix(κ, ColVecs(reshape(X, 1, :)), ColVecs(reshape(Y, 1, :)))
153-
end
154-
155120
function kernelmatrix(
156121
κ::SimpleKernel,
157122
X::AbstractMatrix,
158123
Y::AbstractMatrix;
159124
obsdim = defaultobs,
160125
)
161126
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
162-
if !check_dims(X, Y, feature_dim(obsdim), obsdim)
127+
if !check_dims(X, Y, feature_dim(obsdim))
163128
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
164129
end
165130
_kernelmatrix(κ, X, Y, obsdim)

test/kernels/tensorproduct.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,27 +95,27 @@
9595
trueXY = kernelmatrix(k1, X, Y)
9696
tmp = Matrix{Float64}(undef, 10, 10)
9797

98-
@test kernelmatrix(kernel, X) == trueX
98+
@test kernelmatrix(kernel, X) trueX
9999
@test kernelmatrix(kernel, X'; obsdim = 1) trueX
100100

101101
@test kernelmatrix(kernel, X, Y) trueXY
102102
@test kernelmatrix(kernel, X', Y'; obsdim = 1) trueXY
103103

104104
fill!(tmp, 0)
105105
kernelmatrix!(tmp, kernel, X)
106-
@test tmp == trueX
106+
@test tmp trueX
107107

108108
fill!(tmp, 0)
109109
kernelmatrix!(tmp, kernel, X'; obsdim = 1)
110-
@test tmp == trueX
110+
@test tmp trueX
111111

112112
fill!(tmp, 0)
113113
kernelmatrix!(tmp, kernel, X, Y)
114-
@test tmp == trueXY
114+
@test tmp trueXY
115115

116116
fill!(tmp, 0)
117117
kernelmatrix!(tmp, kernel, X', Y'; obsdim = 1)
118-
@test tmp == trueXY
118+
@test tmp trueXY
119119
end
120120

121121
@testset "kerneldiagmatrix" begin

0 commit comments

Comments
 (0)