Skip to content

Commit b8a483c

Browse files
committed
Improved error throwing and tests
1 parent 052807f commit b8a483c

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

src/kernelmatrix.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ function kernelmatrix!(
1010
X::AbstractMatrix{T₂};
1111
obsdim::Int = defaultobs
1212
) where {T,T₁<:Real,T₂<:Real}
13-
@assert check_dims(K,X,X,feature_dim(obsdim),obsdim) "Dimensions of the target array are not consistent with X and Y"
13+
if !check_dims(K,X,X,feature_dim(obsdim),obsdim)
14+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
15+
end
1416
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
1517
end
1618

@@ -27,7 +29,9 @@ function kernelmatrix!(
2729
Y::AbstractMatrix{T₃};
2830
obsdim::Int = defaultobs
2931
) where {T,T₁,T₂,T₃}
30-
@assert check_dims(K,X,Y,feature_dim(obsdim),obsdim) "Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"
32+
if !check_dims(K,X,Y,feature_dim(obsdim),obsdim)
33+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
34+
end
3135
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
3236
end
3337

@@ -82,7 +86,9 @@ function kernelmatrix(
8286
Y::AbstractMatrix{T₂};
8387
obsdim=defaultobs
8488
) where {T,T₁<:Real,T₂<:Real}
85-
@assert check_dims(X,Y,feature_dim(obsdim),obsdim) "X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"
89+
if !check_dims(X,Y,feature_dim(obsdim),obsdim)
90+
throw(DimensionMismatch("X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"))
91+
end
8692
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
8793
return K
8894
end
@@ -113,6 +119,9 @@ function kerneldiagmatrix!(
113119
X::AbstractMatrix{T₂};
114120
obsdim::Int = defaultobs
115121
) where {T,T₁,T₂}
122+
if length(K) != size(X,obsdim)
123+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
124+
end
116125
if obsdim == 1
117126
for i in eachindex(K)
118127
@inbounds @views K[i] = kernel(κ, X[i,:],X[i,:])

src/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ end
1111

1212

1313
# Take highest Float among possibilities
14-
function promote_float(Tₖ::DataType...)
15-
if length(Tₖ) == 0
16-
return Float64
17-
end
18-
T = promote_type(Tₖ...)
19-
return T <: Real ? T : Float64
20-
end
14+
# function promote_float(Tₖ::DataType...)
15+
# if length(Tₖ) == 0
16+
# return Float64
17+
# end
18+
# T = promote_type(Tₖ...)
19+
# return T <: Real ? T : Float64
20+
# end
2121

2222
check_dims(K,X,Y,featdim,obsdim) = check_dims(X,Y,featdim,obsdim) && (size(K) == (size(X,obsdim),size(Y,obsdim)))
2323

test/test_kernelmatrix.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,29 @@ dims = [10,5]
66

77
A = rand(dims...)
88
B = rand(dims...)
9+
C = rand(8,9)
910
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
1011
Kdiag = [zeros(dims[1]),zeros(dims[2])]
1112
k = SqExponentialKernel()
12-
@testset "Inplace Kernel Matrix" begin
13-
for obsdim in [1,2]
14-
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
15-
@test kernelmatrix!(K[obsdim],k,A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
16-
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
13+
@testset "Kernel Matrix Operations" begin
14+
@testset "Inplace Kernel Matrix" begin
15+
for obsdim in [1,2]
16+
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
17+
@test kernelmatrix!(K[obsdim],k,A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
18+
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
19+
@test_throws DimensionMismatch kernelmatrix!(K[obsdim],k,A,C,obsdim=obsdim)
20+
@test_throws DimensionMismatch kernelmatrix!(K[obsdim],k,C,obsdim=obsdim)
21+
@test_throws DimensionMismatch kerneldiagmatrix!(Kdiag[obsdim],k,C,obsdim=obsdim)
22+
end
1723
end
18-
end
19-
20-
@testset "Kernel matrix" begin
21-
for obsdim in [1,2]
22-
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
23-
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
24-
@test k(A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
25-
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
26-
@test kernel(k,1.0,2.0) == kernel(k,[1.0],[2.0])
24+
@testset "Kernel matrix" begin
25+
for obsdim in [1,2]
26+
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
27+
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
28+
@test k(A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
29+
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
30+
@test kernel(k,1.0,2.0) == kernel(k,[1.0],[2.0])
31+
@test_throws DimensionMismatch kernelmatrix(k,A,C,obsdim=obsdim)
32+
end
2733
end
2834
end

0 commit comments

Comments
 (0)