Skip to content

SelectTransform extended to Symbols #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 21, 2020
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.6.0"
version = "0.6.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
12 changes: 3 additions & 9 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SelectTransform(dims::AbstractVector{Int})
SelectTransform(dims)

Select the dimensions `dims` that the kernel is applied to.
```
Expand All @@ -9,17 +9,11 @@ Select the dimensions `dims` that the kernel is applied to.
transform(tr,X,obsdim=2) == X[dims,:]
```
"""
struct SelectTransform{T<:AbstractVector{Int}} <: Transform
struct SelectTransform{T} <: Transform
select::T
function SelectTransform{V}(dims::V) where {V<:AbstractVector{Int}}
@assert all(dims .> 0) "Selective dimensions should all be positive integers"
return new{V}(dims)
end
end

SelectTransform(x::T) where {T<:AbstractVector{Int}} = SelectTransform{T}(x)

set!(t::SelectTransform{<:AbstractVector{T}}, dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
set!(t::SelectTransform, dims) = t.select .= dims

duplicate(t::SelectTransform,θ) = t

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AxisArrays = "0.4.3"
Distances = "0.9"
FiniteDifferences = "0.10.8"
Flux = "0.10, 0.11"
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using KernelFunctions
using AxisArrays
using Distances
using Kronecker
using LinearAlgebra
Expand Down Expand Up @@ -57,12 +58,19 @@ using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs

@testset "transform" begin
include(joinpath("transform", "transform.jl"))
print(" ")
include(joinpath("transform", "scaletransform.jl"))
print(" ")
include(joinpath("transform", "ardtransform.jl"))
print(" ")
include(joinpath("transform", "lineartransform.jl"))
print(" ")
include(joinpath("transform", "functiontransform.jl"))
print(" ")
include(joinpath("transform", "selecttransform.jl"))
print(" ")
include(joinpath("transform", "chaintransform.jl"))
print(" ")
end
@info "Ran tests on Transform"

Expand Down
91 changes: 90 additions & 1 deletion test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,105 @@
x_cols = ColVecs(randn(rng, maximum(select), 6))
x_rows = RowVecs(randn(rng, 4, maximum(select)))

@testset "$(typeof(x))" for x in [x_vecs, x_cols, x_rows]
Xs = [x_vecs, x_cols, x_rows]

@testset "$(typeof(x))" for x in Xs
x′ = map(t, x)
@test all([t(x[n]) == x[n][select] for n in eachindex(x)])
@test all([t(x[n]) == x′[n] for n in eachindex(x)])
end

symbols = [:a, :b, :c, :d, :e]
select_symbols = [:a, :c, :e]

ts = SelectTransform(select_symbols)

a_vecs = map(x->AxisArray(x, col=symbols), x_vecs)
a_cols = ColVecs(AxisArray(x_cols.X, col=symbols, index=(1:6)))
a_rows = RowVecs(AxisArray(x_rows.X, index=(1:4), col=symbols))

As = [a_vecs, a_cols, a_rows]

@testset "$(typeof(a))" for (a, x) in zip(As, Xs)
a′ = map(ts, a)
x′ = map(t, x)
@test a′ == x′
end

select2 = [2, 3, 5]
KernelFunctions.set!(t, select2)
@test t.select == select2

select_symbols2 = [:b, :c, :e]
KernelFunctions.set!(ts, select_symbols2)
@test ts.select == select_symbols2

@test repr(t) == "Select Transform (dims: $(select2))"
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"

test_ADs(()->transform(SEKernel(), SelectTransform([1,2])))

X = randn(rng, (4, 3))
A = AxisArray(X, row=[:a, :b, :c, :d], col=[:x, :y, :z])
Y = randn(rng, (4, 2))
B = AxisArray(Y, row=[:a, :b, :c, :d], col=[:v, :w])
Z = randn(rng, (2, 3))
C = AxisArray(Z, row=[:e, :f], col=[:x, :y, :z])

tx_row = transform(SEKernel(), SelectTransform([1,2,4]))
ta_row = transform(SEKernel(), SelectTransform([:a,:b,:d]))
tx_col = transform(SEKernel(), SelectTransform([1,3]))
ta_col = transform(SEKernel(), SelectTransform([:x,:z]))

@test kernelmatrix(tx_row, X, obsdim=2) == kernelmatrix(ta_row, A, obsdim=2)
@test kernelmatrix(tx_col, X, obsdim=1) == kernelmatrix(ta_col, A, obsdim=1)

@test kernelmatrix(tx_row, X, Y, obsdim=2) == kernelmatrix(ta_row, A, B, obsdim=2)
@test kernelmatrix(tx_col, X, Z, obsdim=1) == kernelmatrix(ta_col, A, C, obsdim=1)

@testset "$(AD)" for AD in [:Zygote, :ForwardDiff]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@molet Did the test_AD function not work for this code?

gx = gradient(AD, X) do x
testfunction(tx_row, x, 2)
end
ga = gradient(AD, A) do a
testfunction(ta_row, a, 2)
end
@test gx == ga
gx = gradient(AD, X) do x
testfunction(tx_col, x, 1)
end
ga = gradient(AD, A) do a
testfunction(ta_col, a, 1)
end
@test gx == ga
gx = gradient(AD, X) do x
testfunction(tx_row, x, Y, 2)
end
ga = gradient(AD, A) do a
testfunction(ta_row, a, B, 2)
end
@test gx == ga
gx = gradient(AD, X) do x
testfunction(tx_col, x, Z, 1)
end
ga = gradient(AD, A) do a
testfunction(ta_col, a, C, 1)
end
@test gx == ga
end

@testset "$(AD)" for AD in [:ReverseDiff]
@test_broken ga = gradient(AD, A) do a
testfunction(ta_row, a, 2)
end
@test_broken ga = gradient(AD, A) do a
testfunction(ta_col, a, 1)
end
@test_broken ga = gradient(AD, A) do a
testfunction(ta_row, a, B, 2)
end
@test_broken ga = gradient(AD, A) do a
testfunction(ta_col, a, C, 1)
end
end
end