Skip to content

Commit a92fcc2

Browse files
moletLetif Monesdevmotionwilltebbutt
authored
SelectTransform extended to Symbols (#155)
* SelectTransform extended to symbols * Update src/transform/selecttransform.jl Co-authored-by: David Widmann <[email protected]> * Specific constructor for AbstractVector{Int} removed * Update src/transform/selecttransform.jl Co-authored-by: David Widmann <[email protected]> * Update src/transform/selecttransform.jl Co-authored-by: willtebbutt <[email protected]> * docs fixed * test for kernelmatrix using SelectTransform with Symbols * Patch version bumped * AD tests for SelectTransform using Symbols * @test_broken added for ReverseDiff * print(" ") inserted after each file of the transform testset Co-authored-by: Letif Mones <[email protected]> Co-authored-by: David Widmann <[email protected]> Co-authored-by: willtebbutt <[email protected]>
1 parent f7eeefe commit a92fcc2

File tree

5 files changed

+104
-11
lines changed

5 files changed

+104
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.6.0"
3+
version = "0.6.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/transform/selecttransform.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
SelectTransform(dims::AbstractVector{Int})
2+
SelectTransform(dims)
33
44
Select the dimensions `dims` that the kernel is applied to.
55
```
@@ -9,17 +9,11 @@ Select the dimensions `dims` that the kernel is applied to.
99
transform(tr,X,obsdim=2) == X[dims,:]
1010
```
1111
"""
12-
struct SelectTransform{T<:AbstractVector{Int}} <: Transform
12+
struct SelectTransform{T} <: Transform
1313
select::T
14-
function SelectTransform{V}(dims::V) where {V<:AbstractVector{Int}}
15-
@assert all(dims .> 0) "Selective dimensions should all be positive integers"
16-
return new{V}(dims)
17-
end
1814
end
1915

20-
SelectTransform(x::T) where {T<:AbstractVector{Int}} = SelectTransform{T}(x)
21-
22-
set!(t::SelectTransform{<:AbstractVector{T}}, dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
16+
set!(t::SelectTransform, dims) = t.select .= dims
2317

2418
duplicate(t::SelectTransform,θ) = t
2519

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
23
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
34
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
45
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1314
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1415

1516
[compat]
17+
AxisArrays = "0.4.3"
1618
Distances = "0.9"
1719
FiniteDifferences = "0.10.8"
1820
Flux = "0.10, 0.11"

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using KernelFunctions
2+
using AxisArrays
23
using Distances
34
using Kronecker
45
using LinearAlgebra
@@ -57,12 +58,19 @@ using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs
5758

5859
@testset "transform" begin
5960
include(joinpath("transform", "transform.jl"))
61+
print(" ")
6062
include(joinpath("transform", "scaletransform.jl"))
63+
print(" ")
6164
include(joinpath("transform", "ardtransform.jl"))
65+
print(" ")
6266
include(joinpath("transform", "lineartransform.jl"))
67+
print(" ")
6368
include(joinpath("transform", "functiontransform.jl"))
69+
print(" ")
6470
include(joinpath("transform", "selecttransform.jl"))
71+
print(" ")
6572
include(joinpath("transform", "chaintransform.jl"))
73+
print(" ")
6674
end
6775
@info "Ran tests on Transform"
6876

test/transform/selecttransform.jl

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,105 @@
88
x_cols = ColVecs(randn(rng, maximum(select), 6))
99
x_rows = RowVecs(randn(rng, 4, maximum(select)))
1010

11-
@testset "$(typeof(x))" for x in [x_vecs, x_cols, x_rows]
11+
Xs = [x_vecs, x_cols, x_rows]
12+
13+
@testset "$(typeof(x))" for x in Xs
1214
x′ = map(t, x)
1315
@test all([t(x[n]) == x[n][select] for n in eachindex(x)])
1416
@test all([t(x[n]) == x′[n] for n in eachindex(x)])
1517
end
1618

19+
symbols = [:a, :b, :c, :d, :e]
20+
select_symbols = [:a, :c, :e]
21+
22+
ts = SelectTransform(select_symbols)
23+
24+
a_vecs = map(x->AxisArray(x, col=symbols), x_vecs)
25+
a_cols = ColVecs(AxisArray(x_cols.X, col=symbols, index=(1:6)))
26+
a_rows = RowVecs(AxisArray(x_rows.X, index=(1:4), col=symbols))
27+
28+
As = [a_vecs, a_cols, a_rows]
29+
30+
@testset "$(typeof(a))" for (a, x) in zip(As, Xs)
31+
a′ = map(ts, a)
32+
x′ = map(t, x)
33+
@test a′ == x′
34+
end
35+
1736
select2 = [2, 3, 5]
1837
KernelFunctions.set!(t, select2)
1938
@test t.select == select2
2039

40+
select_symbols2 = [:b, :c, :e]
41+
KernelFunctions.set!(ts, select_symbols2)
42+
@test ts.select == select_symbols2
43+
2144
@test repr(t) == "Select Transform (dims: $(select2))"
45+
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"
46+
2247
test_ADs(()->transform(SEKernel(), SelectTransform([1,2])))
48+
49+
X = randn(rng, (4, 3))
50+
A = AxisArray(X, row=[:a, :b, :c, :d], col=[:x, :y, :z])
51+
Y = randn(rng, (4, 2))
52+
B = AxisArray(Y, row=[:a, :b, :c, :d], col=[:v, :w])
53+
Z = randn(rng, (2, 3))
54+
C = AxisArray(Z, row=[:e, :f], col=[:x, :y, :z])
55+
56+
tx_row = transform(SEKernel(), SelectTransform([1,2,4]))
57+
ta_row = transform(SEKernel(), SelectTransform([:a,:b,:d]))
58+
tx_col = transform(SEKernel(), SelectTransform([1,3]))
59+
ta_col = transform(SEKernel(), SelectTransform([:x,:z]))
60+
61+
@test kernelmatrix(tx_row, X, obsdim=2) == kernelmatrix(ta_row, A, obsdim=2)
62+
@test kernelmatrix(tx_col, X, obsdim=1) == kernelmatrix(ta_col, A, obsdim=1)
63+
64+
@test kernelmatrix(tx_row, X, Y, obsdim=2) == kernelmatrix(ta_row, A, B, obsdim=2)
65+
@test kernelmatrix(tx_col, X, Z, obsdim=1) == kernelmatrix(ta_col, A, C, obsdim=1)
66+
67+
@testset "$(AD)" for AD in [:Zygote, :ForwardDiff]
68+
gx = gradient(AD, X) do x
69+
testfunction(tx_row, x, 2)
70+
end
71+
ga = gradient(AD, A) do a
72+
testfunction(ta_row, a, 2)
73+
end
74+
@test gx == ga
75+
gx = gradient(AD, X) do x
76+
testfunction(tx_col, x, 1)
77+
end
78+
ga = gradient(AD, A) do a
79+
testfunction(ta_col, a, 1)
80+
end
81+
@test gx == ga
82+
gx = gradient(AD, X) do x
83+
testfunction(tx_row, x, Y, 2)
84+
end
85+
ga = gradient(AD, A) do a
86+
testfunction(ta_row, a, B, 2)
87+
end
88+
@test gx == ga
89+
gx = gradient(AD, X) do x
90+
testfunction(tx_col, x, Z, 1)
91+
end
92+
ga = gradient(AD, A) do a
93+
testfunction(ta_col, a, C, 1)
94+
end
95+
@test gx == ga
96+
end
97+
98+
@testset "$(AD)" for AD in [:ReverseDiff]
99+
@test_broken ga = gradient(AD, A) do a
100+
testfunction(ta_row, a, 2)
101+
end
102+
@test_broken ga = gradient(AD, A) do a
103+
testfunction(ta_col, a, 1)
104+
end
105+
@test_broken ga = gradient(AD, A) do a
106+
testfunction(ta_row, a, B, 2)
107+
end
108+
@test_broken ga = gradient(AD, A) do a
109+
testfunction(ta_col, a, C, 1)
110+
end
111+
end
23112
end

0 commit comments

Comments
 (0)