Skip to content

Commit c09ec42

Browse files
authored
Merge pull request JuliaGaussianProcesses#84 from theogf/colvecs
Added ColVecs
2 parents 89baf9c + b5d0f9f commit c09ec42

File tree

6 files changed

+68
-2
lines changed

6 files changed

+68
-2
lines changed

src/transform/transform.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ Return exactly the input
2121
"""
2222
struct IdentityTransform <: Transform end
2323

24-
apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x
24+
apply(t::IdentityTransform, x; obsdim::Int = defaultobs) = x
25+
26+
apply(t::Transform, x::ColVecs; obsdim::Int = defaultobs) = ColVecs(apply(t, x.X, obsdim = 2))
2527

2628
### TODO Maybe defining adjoints could help but so far it's not working
2729

src/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@ macro check_args(K, param, cond, desc=string(cond))
1010
end
1111

1212

13+
"""
14+
ColVecs(X::AbstractMatrix)
15+
16+
A lightweight box for an `AbstractMatrix` to make it behave like a vector of vectors.
17+
"""
18+
struct ColVecs{T, TX<:AbstractMatrix{T}, S} <: AbstractVector{S}
19+
X::TX
20+
function ColVecs(X::TX) where {T, TX<:AbstractMatrix{T}}
21+
S = typeof(view(X, :, 1))
22+
new{T, TX, S}(X)
23+
end
24+
end
25+
26+
Base.size(D::ColVecs) = (size(D.X, 2),)
27+
Base.getindex(D::ColVecs, i::Int) = view(D.X, :, i)
28+
Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
29+
1330
# Take highest Float among possibilities
1431
# function promote_float(Tₖ::DataType...)
1532
# if length(Tₖ) == 0

src/zygote_adjoints.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
end
55
end
66

7+
@adjoint function ColVecs(X::AbstractMatrix)
8+
back::NamedTuple) =.X,)
9+
back::AbstractMatrix) = (Δ,)
10+
function back::AbstractVector{<:AbstractVector{<:Real}})
11+
throw(error("In slow method"))
12+
end
13+
return ColVecs(X), back
14+
end
15+
716
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
817
# d = evaluate(s, x, y)
918
# s = sum(sin.(π*(x-y)))

test/trainable.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testset "trainable" begin
2+
using Flux: params
23
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5; r = rand(3)
34

45
kc = ConstantKernel(c=c)

test/transform/transform.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
rng = MersenneTwister(123546)
44
X = rand(rng, dims...)
55
@testset "IdentityTransform" begin
6-
@test KernelFunctions.apply(IdentityTransform(),X)==X
6+
@test KernelFunctions.apply(IdentityTransform(), X) == X
7+
end
8+
@testset "ColVecs" begin
9+
vX = KernelFunctions.ColVecs(X)
10+
t = ARDTransform(rand(dims[1]))
11+
@test KernelFunctions.apply(t, vX) KernelFunctions.ColVecs(KernelFunctions.apply(t, X, obsdim = 2))
12+
13+
Y = rand(rng, reverse(dims)...)
14+
vY = KernelFunctions.ColVecs(Y')
15+
t = ARDTransform(rand(dims[1]))
16+
@test KernelFunctions.apply(t, vY) KernelFunctions.ColVecs(KernelFunctions.apply(t, Y, obsdim = 1)')
717
end
818
end

test/utils.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
11
@testset "utils" begin
2+
using KernelFunctions: ColVecs
3+
rng, N, D = MersenneTwister(123456), 10, 2
4+
x, X = randn(rng, N), randn(rng, D, N)
25

6+
# Test Matrix data sets.
7+
@testset "ColVecs" begin
8+
DX = ColVecs(X)
9+
@test DX == DX
10+
@test size(DX) == (N,)
11+
@test length(DX) == N
12+
@test getindex(DX, 5) isa AbstractVector
13+
@test getindex(DX, 5) == X[:, 5]
14+
@test getindex(DX, 1:2:6) isa ColVecs
15+
@test getindex(DX, 1:2:6) == ColVecs(X[:, 1:2:6])
16+
@test getindex(DX, :) == ColVecs(X)
17+
@test eachindex(DX) == 1:N
18+
@test first(DX) == X[:, 1]
19+
20+
let
21+
@test Zygote.pullback(ColVecs, X)[1] == DX
22+
DX, back = Zygote.pullback(ColVecs, X)
23+
@test back((X=ones(size(X)),))[1] == ones(size(X))
24+
25+
@test Zygote.pullback(DX->DX.X, DX)[1] == X
26+
X_, back = Zygote.pullback(DX->DX.X, DX)
27+
@test back(ones(size(X)))[1].X == ones(size(X))
28+
end
29+
end
330
end

0 commit comments

Comments
 (0)