|
8 | 8 | x_cols = ColVecs(randn(rng, maximum(select), 6))
|
9 | 9 | x_rows = RowVecs(randn(rng, 4, maximum(select)))
|
10 | 10 |
|
11 |
| - @testset "$(typeof(x))" for x in [x_vecs, x_cols, x_rows] |
| 11 | + Xs = [x_vecs, x_cols, x_rows] |
| 12 | + |
| 13 | + @testset "$(typeof(x))" for x in Xs |
12 | 14 | x′ = map(t, x)
|
13 | 15 | @test all([t(x[n]) == x[n][select] for n in eachindex(x)])
|
14 | 16 | @test all([t(x[n]) == x′[n] for n in eachindex(x)])
|
15 | 17 | end
|
16 | 18 |
|
| 19 | + symbols = [:a, :b, :c, :d, :e] |
| 20 | + select_symbols = [:a, :c, :e] |
| 21 | + |
| 22 | + ts = SelectTransform(select_symbols) |
| 23 | + |
| 24 | + a_vecs = map(x->AxisArray(x, col=symbols), x_vecs) |
| 25 | + a_cols = ColVecs(AxisArray(x_cols.X, col=symbols, index=(1:6))) |
| 26 | + a_rows = RowVecs(AxisArray(x_rows.X, index=(1:4), col=symbols)) |
| 27 | + |
| 28 | + As = [a_vecs, a_cols, a_rows] |
| 29 | + |
| 30 | + @testset "$(typeof(a))" for (a, x) in zip(As, Xs) |
| 31 | + a′ = map(ts, a) |
| 32 | + x′ = map(t, x) |
| 33 | + @test a′ == x′ |
| 34 | + end |
| 35 | + |
17 | 36 | select2 = [2, 3, 5]
|
18 | 37 | KernelFunctions.set!(t, select2)
|
19 | 38 | @test t.select == select2
|
20 | 39 |
|
| 40 | + select_symbols2 = [:b, :c, :e] |
| 41 | + KernelFunctions.set!(ts, select_symbols2) |
| 42 | + @test ts.select == select_symbols2 |
| 43 | + |
21 | 44 | @test repr(t) == "Select Transform (dims: $(select2))"
|
| 45 | + @test repr(ts) == "Select Transform (dims: $(select_symbols2))" |
| 46 | + |
22 | 47 | test_ADs(()->transform(SEKernel(), SelectTransform([1,2])))
|
| 48 | + |
| 49 | + X = randn(rng, (4, 3)) |
| 50 | + A = AxisArray(X, row=[:a, :b, :c, :d], col=[:x, :y, :z]) |
| 51 | + Y = randn(rng, (4, 2)) |
| 52 | + B = AxisArray(Y, row=[:a, :b, :c, :d], col=[:v, :w]) |
| 53 | + Z = randn(rng, (2, 3)) |
| 54 | + C = AxisArray(Z, row=[:e, :f], col=[:x, :y, :z]) |
| 55 | + |
| 56 | + tx_row = transform(SEKernel(), SelectTransform([1,2,4])) |
| 57 | + ta_row = transform(SEKernel(), SelectTransform([:a,:b,:d])) |
| 58 | + tx_col = transform(SEKernel(), SelectTransform([1,3])) |
| 59 | + ta_col = transform(SEKernel(), SelectTransform([:x,:z])) |
| 60 | + |
| 61 | + @test kernelmatrix(tx_row, X, obsdim=2) == kernelmatrix(ta_row, A, obsdim=2) |
| 62 | + @test kernelmatrix(tx_col, X, obsdim=1) == kernelmatrix(ta_col, A, obsdim=1) |
| 63 | + |
| 64 | + @test kernelmatrix(tx_row, X, Y, obsdim=2) == kernelmatrix(ta_row, A, B, obsdim=2) |
| 65 | + @test kernelmatrix(tx_col, X, Z, obsdim=1) == kernelmatrix(ta_col, A, C, obsdim=1) |
| 66 | + |
| 67 | + @testset "$(AD)" for AD in [:Zygote, :ForwardDiff] |
| 68 | + gx = gradient(AD, X) do x |
| 69 | + testfunction(tx_row, x, 2) |
| 70 | + end |
| 71 | + ga = gradient(AD, A) do a |
| 72 | + testfunction(ta_row, a, 2) |
| 73 | + end |
| 74 | + @test gx == ga |
| 75 | + gx = gradient(AD, X) do x |
| 76 | + testfunction(tx_col, x, 1) |
| 77 | + end |
| 78 | + ga = gradient(AD, A) do a |
| 79 | + testfunction(ta_col, a, 1) |
| 80 | + end |
| 81 | + @test gx == ga |
| 82 | + gx = gradient(AD, X) do x |
| 83 | + testfunction(tx_row, x, Y, 2) |
| 84 | + end |
| 85 | + ga = gradient(AD, A) do a |
| 86 | + testfunction(ta_row, a, B, 2) |
| 87 | + end |
| 88 | + @test gx == ga |
| 89 | + gx = gradient(AD, X) do x |
| 90 | + testfunction(tx_col, x, Z, 1) |
| 91 | + end |
| 92 | + ga = gradient(AD, A) do a |
| 93 | + testfunction(ta_col, a, C, 1) |
| 94 | + end |
| 95 | + @test gx == ga |
| 96 | + end |
| 97 | + |
| 98 | + @testset "$(AD)" for AD in [:ReverseDiff] |
| 99 | + @test_broken ga = gradient(AD, A) do a |
| 100 | + testfunction(ta_row, a, 2) |
| 101 | + end |
| 102 | + @test_broken ga = gradient(AD, A) do a |
| 103 | + testfunction(ta_col, a, 1) |
| 104 | + end |
| 105 | + @test_broken ga = gradient(AD, A) do a |
| 106 | + testfunction(ta_row, a, B, 2) |
| 107 | + end |
| 108 | + @test_broken ga = gradient(AD, A) do a |
| 109 | + testfunction(ta_col, a, C, 1) |
| 110 | + end |
| 111 | + end |
23 | 112 | end
|
0 commit comments