Skip to content

Commit fb37557

Browse files
authored
Merge pull request #114 from theogf/test_AD
Series of tests for AD
2 parents 3b0cf61 + e94973e commit fb37557

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+375
-229
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
- name: CompatHelper.main()
1717
env:
1818
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
19-
run: julia -e 'using CompatHelper; CompatHelper.main()'
19+
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "test"])'

Project.toml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,3 @@ StatsBase = "0.32, 0.33"
2222
StatsFuns = "0.8, 0.9"
2323
ZygoteRules = "0.2"
2424
julia = "1.3"
25-
26-
[extras]
27-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
28-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
29-
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
30-
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
31-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
32-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
33-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
34-
35-
[targets]
36-
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3434
using Compat
3535
using Requires
3636
using Distances, LinearAlgebra
37-
using SpecialFunctions: logabsgamma, besselk
38-
using ZygoteRules: @adjoint
37+
using SpecialFunctions: logabsgamma, besselk, polygamma
38+
using ZygoteRules: @adjoint, pullback
3939
using StatsFuns: logtwo
4040
using InteractiveUtils: subtypes
4141
using StatsBase

src/basekernels/matern.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ end
1717

1818
@inline function kappa::MaternKernel, d::Real)
1919
ν = first.ν)
20-
iszero(d) ? one(d) :
21-
exp(
22-
(one(d) - ν) * logtwo - logabsgamma(ν)[1] +
23-
ν * log(sqrt(2ν) * d) +
24-
log(besselk(ν, sqrt(2ν) * d))
25-
)
20+
iszero(d) ? one(d) : _matern(ν, d)
21+
end
22+
23+
function _matern::Real, d::Real)
24+
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
2625
end
2726

2827
metric(::MaternKernel) = Euclidean()

src/distances/delta.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
struct Delta <: Distances.PreMetric
22
end
33

4-
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return a == b
99
end
1010

11+
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
12+
1113
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
12-
@inline (dist::Delta)(a::Number,b::Number) = a == b
14+
@inline (dist::Delta)(a::Number, b::Number) = a == b

src/distances/dotproduct.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
struct DotProduct <: Distances.PreMetric end
22
# struct DotProduct <: Distances.UnionSemiMetric end
33

4-
@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return dot(a,b)
99
end
1010

11+
Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
12+
1113
@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
1214
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
1315
@inline (dist::DotProduct)(a::Number,b::Number) = a * b

src/distances/sinus.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r
88
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
99
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
1010

11-
@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
11+
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)
12+
13+
@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) where {T}
1214
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
1315
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
1416
end

src/transform/ardtransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ dim(t::ARDTransform) = length(t.v)
2424
(t::ARDTransform)(x::Real) = first(t.v) * x
2525
(t::ARDTransform)(x) = t.v .* x
2626

27-
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
28-
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29-
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
27+
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
28+
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29+
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
3030

3131
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
3232

src/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor
2727

2828
(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)
2929

30-
function Base.map(t::ChainTransform, x::AbstractVector)
30+
function _map(t::ChainTransform, x::AbstractVector)
3131
return foldl((x, t) -> map(t, x), t.transforms; init=x)
3232
end
3333

src/transform/functiontransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ end
1515

1616
(t::FunctionTransform)(x) = t.f(x)
1717

18-
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19-
Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20-
Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
18+
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19+
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20+
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
2121

2222
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
2323

src/transform/lineartransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ end
2727
(t::LinearTransform)(x::Real) = vec(t.A * x)
2828
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
2929

30-
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
31-
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32-
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
30+
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
31+
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32+
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
3333

3434
function Base.show(io::IO, t::LinearTransform)
3535
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")

src/transform/scaletransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]
1919

2020
(t::ScaleTransform)(x) = first(t.s) .* x
2121

22-
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
23-
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
24-
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
22+
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
23+
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
24+
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
2525

2626
Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))
2727

src/transform/selecttransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ duplicate(t::SelectTransform,θ) = t
2525

2626
(t::SelectTransform)(x::AbstractVector) = view(x, t.select)
2727

28-
Base.map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
29-
Base.map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
28+
_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
29+
_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
3030

3131
Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")

src/transform/transform.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@ include("functiontransform.jl")
55
include("selecttransform.jl")
66
include("chaintransform.jl")
77

8-
"""
9-
apply(t::Transform, x; obsdim::Int=defaultobs)
108

11-
Apply the transform `t` vector-wise on the array `x`
12-
"""
13-
apply
9+
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
1410

1511
"""
1612
IdentityTransform()
@@ -20,7 +16,7 @@ Return exactly the input
2016
struct IdentityTransform <: Transform end
2117

2218
(t::IdentityTransform)(x) = x
23-
Base.map(::IdentityTransform, x::AbstractVector) = x
19+
_map(::IdentityTransform, x::AbstractVector) = x
2420

2521
### TODO Maybe defining adjoints could help but so far it's not working
2622

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
hadamard(x, y) = x .* y
22

3+
loggamma(x) = first(logabsgamma(x))
4+
35
# Macro for checking arguments
46
macro check_args(K, param, cond, desc=string(cond))
57
quote
@@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector)
124126
))
125127
end
126128
end
127-

src/zygote_adjoints.jl

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,78 @@
1+
## Adjoints Delta
2+
@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector)
3+
evaluate(s, x, y), Δ -> begin
4+
(nothing, nothing, nothing)
5+
end
6+
end
7+
8+
@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9+
D = pairwise(d, X, Y; dims = dims)
10+
if dims == 1
11+
return D, Δ -> (nothing, nothing, nothing)
12+
else
13+
return D, Δ -> (nothing, nothing, nothing)
14+
end
15+
end
16+
17+
@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
18+
D = pairwise(d, X; dims = dims)
19+
if dims == 1
20+
return D, Δ -> (nothing, nothing)
21+
else
22+
return D, Δ -> (nothing, nothing)
23+
end
24+
end
25+
26+
## Adjoints DotProduct
127
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
228
dot(x, y), Δ -> begin
329
(nothing, Δ .* y, Δ .* x)
430
end
531
end
632

33+
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
34+
D = pairwise(d, X, Y; dims = dims)
35+
if dims == 1
36+
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
37+
else
38+
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
39+
end
40+
end
41+
42+
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
43+
D = pairwise(d, X; dims = dims)
44+
if dims == 1
45+
return D, Δ -> (nothing, 2 * Δ * X)
46+
else
47+
return D, Δ -> (nothing, 2 * X * Δ)
48+
end
49+
end
50+
51+
## Adjoints Sinus
52+
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
53+
d = (x - y)
54+
sind = sinpi.(d)
55+
val = sum(abs2, sind ./ s.r)
56+
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
57+
val, Δ -> begin
58+
((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx)
59+
end
60+
end
61+
62+
@adjoint function loggamma(x)
63+
first(logabsgamma(x)) , Δ ->.* polygamma(0, x), )
64+
end
65+
66+
@adjoint function kappa::MaternKernel, d::Real)
67+
ν = first.ν)
68+
val, grad = pullback(_matern, ν, d)
69+
return ((iszero(d) ? one(d) : val),
70+
Δ -> begin
71+
= grad(Δ)
72+
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2])
73+
end)
74+
end
75+
776
@adjoint function ColVecs(X::AbstractMatrix)
877
back::NamedTuple) =.X,)
978
back::AbstractMatrix) = (Δ,)
@@ -22,10 +91,10 @@ end
2291
return RowVecs(X), back
2392
end
2493

25-
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
26-
# d = evaluate(s, x, y)
27-
# s = sum(sin.(π*(x-y)))
28-
# d, Δ -> begin
29-
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d)
30-
# end
31-
# end
94+
@adjoint function Base.map(t::Transform, X::ColVecs)
95+
pullback(_map, t, X)
96+
end
97+
98+
@adjoint function Base.map(t::Transform, X::RowVecs)
99+
pullback(_map, t, X)
100+
end

test/Project.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[deps]
2+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6+
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
10+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
11+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
12+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
14+
15+
[compat]
16+
Distances = "0.9"
17+
FiniteDifferences = "0.10"
18+
Flux = "0.10"
19+
ForwardDiff = "0.10"
20+
Kronecker = "0.4"
21+
PDMats = "0.9"
22+
ReverseDiff = "1.2"
23+
SpecialFunctions = "0.10"
24+
Zygote = "0.4"

test/basekernels/constant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
@test kappa(k,2.0) == 0.0
66
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
77
@test repr(k) == "Zero Kernel"
8+
test_ADs(ZeroKernel)
89
end
910
@testset "WhiteKernel" begin
1011
k = WhiteKernel()
@@ -14,6 +15,7 @@
1415
@test EyeKernel == WhiteKernel
1516
@test metric(WhiteKernel()) == KernelFunctions.Delta()
1617
@test repr(k) == "White Kernel"
18+
test_ADs(WhiteKernel)
1719
end
1820
@testset "ConstantKernel" begin
1921
c = 2.0
@@ -24,5 +26,6 @@
2426
@test metric(ConstantKernel()) == KernelFunctions.Delta()
2527
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
2628
@test repr(k) == "Constant Kernel (c = $(c))"
29+
test_ADs(c->ConstantKernel(c=first(c)), [c])
2730
end
2831
end

test/basekernels/cosine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
@test kappa(k,x) cospi(x) atol=1e-5
1313
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
1414
@test repr(k) == "Cosine Kernel"
15+
test_ADs(CosineKernel)
1516
end

test/basekernels/exponential.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
@test SEKernel == SqExponentialKernel
1515
@test repr(k) == "Squared Exponential Kernel"
1616
@test KernelFunctions.iskroncompatible(k) == true
17+
test_ADs(SEKernel)
1718
end
1819
@testset "ExponentialKernel" begin
1920
k = ExponentialKernel()
@@ -24,6 +25,7 @@
2425
@test repr(k) == "Exponential Kernel"
2526
@test LaplacianKernel == ExponentialKernel
2627
@test KernelFunctions.iskroncompatible(k) == true
28+
test_ADs(ExponentialKernel)
2729
end
2830
@testset "GammaExponentialKernel" begin
2931
γ = 2.0
@@ -36,7 +38,8 @@
3638
@test metric(GammaExponentialKernel=2.0)) == SqEuclidean()
3739
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
3840
@test KernelFunctions.iskroncompatible(k) == true
39-
41+
test_ADs-> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
42+
@test_broken "Zygote gradient given γ"
4043
#Coherence :
4144
@test GammaExponentialKernel=1.0)(v1,v2) SqExponentialKernel()(v1,v2)
4245
@test GammaExponentialKernel=0.5)(v1,v2) ExponentialKernel()(v1,v2)

test/basekernels/exponentiated.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
@test k(v1,v2) exp(dot(v1,v2))
1111
@test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct()
1212
@test repr(k) == "Exponentiated Kernel"
13+
test_ADs(ExponentiatedKernel)
1314
end

0 commit comments

Comments
 (0)