|
1 | 1 | @testset "maha" begin
|
2 | 2 | rng = MersenneTwister(123456)
|
3 |
| - x = 2 * rand(rng) |
4 | 3 | D_in = 3
|
5 | 4 | v1 = rand(rng, D_in)
|
6 | 5 | v2 = rand(rng, D_in)
|
7 | 6 |
|
8 |
| - |
9 | 7 | U = UpperTriangular(rand(rng, 3,3))
|
10 | 8 | P = Matrix(Cholesky(U, 'U', 0))
|
11 | 9 | @assert isposdef(P)
|
12 | 10 |
|
13 |
| - k = MahalanobisKernel(P=P) |
14 |
| - |
15 |
| - @test kappa(k, x) == exp(-x) |
| 11 | + k = @test_deprecated MahalanobisKernel(P=P) |
| 12 | + @test k isa TransformedKernel{SqExponentialKernel,<:LinearTransform} |
| 13 | + @test k.transform.A ≈ sqrt(2) .* U |
16 | 14 | @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
|
17 |
| - @test kappa(ExponentialKernel(), x) == kappa(k, x) |
18 |
| - @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))" |
19 |
| - |
20 |
| - M1, M2 = rand(rng,3,2), rand(rng,3,2) |
21 |
| - |
22 |
| - function FiniteDifferences.to_vec(dist::SqMahalanobis) |
23 |
| - return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) |
24 |
| - end |
25 |
| - a = rand() |
26 |
| - |
27 |
| - function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) |
28 |
| - return MahalanobisKernel(P=Array(U'*U))(v1, v2) |
29 |
| - end |
30 |
| - |
31 |
| - @test all(FiniteDifferences.j′vp(FDM, test_mahakernel, a, U, v1, v2)[1] .≈ |
32 |
| - UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1])) |
33 |
| - |
34 |
| - function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) |
35 |
| - return SqMahalanobis(Array(U'*U))(v1, v2) |
36 |
| - end |
37 |
| - |
38 |
| - @test all(FiniteDifferences.j′vp(FDM, test_sqmaha, a, U, v1, v2)[1] .≈ |
39 |
| - UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1])) |
40 |
| - |
41 |
| - # test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote]) |
42 |
| - @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" |
43 | 15 |
|
44 | 16 | # Standardised tests.
|
45 | 17 | @testset "ColVecs" begin
|
|
54 | 26 | x2 = RowVecs(randn(2, D_in))
|
55 | 27 | TestUtils.test_interface(k, x0, x1, x2)
|
56 | 28 | end
|
57 |
| - test_params(k, (P,)) |
58 | 29 | end
|
0 commit comments