Skip to content

Commit d52def6

Browse files
authored
Merge pull request #89 from theogf/VecVecs
Added VecOfVecs and RowVecs
2 parents c09ec42 + 1bd1b26 commit d52def6

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

src/utils.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,20 @@ macro check_args(K, param, cond, desc=string(cond))
99
end
1010
end
1111

12+
function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
13+
@assert obsdim (1, 2) "obsdim should be 1 or 2"
14+
if obsdim == 1
15+
RowVecs(X)
16+
else
17+
ColVecs(X)
18+
end
19+
end
1220

1321
"""
1422
ColVecs(X::AbstractMatrix)
1523
16-
A lightweight box for an `AbstractMatrix` to make it behave like a vector of vectors.
24+
A lightweight wrapper for an `AbstractMatrix` to make it behave like a vector of vectors.
25+
Each vector represents a column of the matrix
1726
"""
1827
struct ColVecs{T, TX<:AbstractMatrix{T}, S} <: AbstractVector{S}
1928
X::TX
@@ -27,6 +36,24 @@ Base.size(D::ColVecs) = (size(D.X, 2),)
2736
Base.getindex(D::ColVecs, i::Int) = view(D.X, :, i)
2837
Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
2938

39+
"""
40+
RowVecs(X::AbstractMatrix)
41+
42+
A lightweight wrapper for an `AbstractMatrix` to make it behave like a vector of vectors.
43+
Each vector represents a row of the matrix
44+
"""
45+
struct RowVecs{T, TX<:AbstractMatrix{T}, S} <: AbstractVector{S}
46+
X::TX
47+
function RowVecs(X::TX) where {T, TX<:AbstractMatrix{T}}
48+
S = typeof(view(X, 1, :))
49+
new{T, TX, S}(X)
50+
end
51+
end
52+
53+
Base.size(D::RowVecs) = (size(D.X, 1),)
54+
Base.getindex(D::RowVecs, i::Int) = view(D.X, i, :)
55+
Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :))
56+
3057
# Take highest Float among possibilities
3158
# function promote_float(Tₖ::DataType...)
3259
# if length(Tₖ) == 0

src/zygote_adjoints.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ end
1313
return ColVecs(X), back
1414
end
1515

16+
@adjoint function RowVecs(X::AbstractMatrix)
17+
back::NamedTuple) =.X,)
18+
back::AbstractMatrix) = (Δ,)
19+
function back::AbstractVector{<:AbstractVector{<:Real}})
20+
throw(error("In slow method"))
21+
end
22+
return RowVecs(X), back
23+
end
24+
1625
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
1726
# d = evaluate(s, x, y)
1827
# s = sum(sin.(π*(x-y)))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
using KernelFunctions
12
using Distances
23
using FiniteDifferences
34
using Flux
4-
using KernelFunctions
55
using Kronecker
66
using LinearAlgebra
77
using PDMats

test/utils.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
@testset "utils" begin
2-
using KernelFunctions: ColVecs
3-
rng, N, D = MersenneTwister(123456), 10, 2
2+
using KernelFunctions: vec_of_vecs, ColVecs, RowVecs
3+
rng, N, D = MersenneTwister(123456), 10, 4
44
x, X = randn(rng, N), randn(rng, D, N)
5-
5+
@testset "VecOfVecs" begin
6+
@test vec_of_vecs(X, obsdim = 2) == ColVecs(X)
7+
@test vec_of_vecs(X, obsdim = 1) == RowVecs(X)
8+
end
69
# Test Matrix data sets.
710
@testset "ColVecs" begin
811
DX = ColVecs(X)
@@ -22,6 +25,29 @@
2225
DX, back = Zygote.pullback(ColVecs, X)
2326
@test back((X=ones(size(X)),))[1] == ones(size(X))
2427

28+
@test Zygote.pullback(DX->DX.X, DX)[1] == X
29+
X_, back = Zygote.pullback(DX->DX.X, DX)
30+
@test back(ones(size(X)))[1].X == ones(size(X))
31+
end
32+
end
33+
@testset "RowVecs" begin
34+
DX = RowVecs(X)
35+
@test DX == DX
36+
@test size(DX) == (D,)
37+
@test length(DX) == D
38+
@test getindex(DX, 2) isa AbstractVector
39+
@test getindex(DX, 2) == X[2, :]
40+
@test getindex(DX, 1:3) isa RowVecs
41+
@test getindex(DX, 1:3) == RowVecs(X[1:3, :])
42+
@test getindex(DX, :) == RowVecs(X)
43+
@test eachindex(DX) == 1:D
44+
@test first(DX) == X[1, :]
45+
46+
let
47+
@test Zygote.pullback(RowVecs, X)[1] == DX
48+
DX, back = Zygote.pullback(RowVecs, X)
49+
@test back((X=ones(size(X)),))[1] == ones(size(X))
50+
2551
@test Zygote.pullback(DX->DX.X, DX)[1] == X
2652
X_, back = Zygote.pullback(DX->DX.X, DX)
2753
@test back(ones(size(X)))[1].X == ones(size(X))

0 commit comments

Comments
 (0)