Skip to content

Commit 2e92ecd

Browse files
authored
Length check before applying TensorProduct (#111)
* Length check before applying TensorProduct * Change error to thrown to DimensionMismatch
1 parent 5f0cdc2 commit 2e92ecd

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

src/kernels/tensorproduct.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ end
2323
Base.length(kernel::TensorProduct) = length(kernel.kernels)
2424

2525
function (kernel::TensorProduct)(x, y)
26+
if !(length(x) == length(y) == length(kernel))
27+
throw(DimensionMismatch("number of kernels and number of features
28+
are not consistent"))
29+
end
2630
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
2731
end
2832

test/kernels/tensorproduct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
@test kernel1.kernels === (k1, k2) === TensorProduct((k1, k2)).kernels
1515
@test length(kernel1) == length(kernel2) == 2
16+
@test_throws DimensionMismatch kernel1(rand(3), rand(3))
1617

1718
@testset "val" begin
1819
for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2]))

0 commit comments

Comments
 (0)