
Commit 5c24f1c

[WIP] Fix some AD issues with various kernels (#154)
- Defines kernelmatrix function for NeuralNetworkKernel.
- Defines Zygote adjoints for the Mahalanobis distance metric.
- Zygote tests pass for the Exponential, FBM, NN and Gabor kernels.

* Zygote passes for Exponential and FBM kernels
* Zygote passes NN kernel
* Zygote passes Gabor kernel
* Address code review
* Fix mutating arrays problem for maha kernel
* Add adjoint for maha distance metric
* Fix Zygote adjoint
* Fix adjoint typo
* Fix buggy version of pairwise adjoint
* Fix typo
* Forgot to add adjoint macro
* Add pairwise SqMahalanobis adjoint and test of SqMahalanobis
* Maha kernel tests
* Fix Zygote adjoint for Mahalanobis
* Fix docs for Matern
* Make maha tests more readable
* Address style issues
* Fix bugs in tests and adjoints
* Fix maha tests
* Remove pairwise maha adjoints for now
* Fix style issues
* Update maha.jl
* Fix style in zygote_adjoints.jl
1 parent 8c99314 commit 5c24f1c

File tree

10 files changed, +105 −33 lines changed

src/basekernels/maha.jl

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 Mahalanobis distance-based kernel given by
 ```math
-    κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
+    κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
 ```
 where the matrix P is the metric.
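The docstring change reflects that `P` is applied directly as the metric rather than being inverted. A minimal sketch of what that reading implies, reusing only names that appear elsewhere in this diff (`MahalanobisKernel`, `sqmahalanobis`); the inputs are illustrative:

```julia
using KernelFunctions, Distances, LinearAlgebra

# Any positive-definite matrix can serve as the metric P.
P = let A = rand(3, 3)
    A'A + I
end
k = MahalanobisKernel(P=P)
x, y = rand(3), rand(3)

# Under the corrected docstring, P enters the quadratic form directly:
k(x, y) ≈ exp(-sqmahalanobis(x, y, P))
```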

src/basekernels/matern.jl

Lines changed: 3 additions & 1 deletion

@@ -5,7 +5,9 @@ The matern kernel is a Mercer kernel given by the formula:
 ```
 κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
 ```
-For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref) for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
+For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use
+[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref) for `n=1`,
+[`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
 """
 struct MaternKernel{Tν<:Real} <: SimpleKernel
     ν::Vector{Tν}
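A quick numerical check of the half-integer special cases named in the reflowed docstring; a sketch that assumes the keyword constructors used elsewhere in this package:

```julia
using KernelFunctions

x, y = rand(3), rand(3)

# ν = n + 1/2 reduces to the dedicated kernels:
MaternKernel(ν=0.5)(x, y) ≈ ExponentialKernel()(x, y)  # n = 0
MaternKernel(ν=1.5)(x, y) ≈ Matern32Kernel()(x, y)     # n = 1
MaternKernel(ν=2.5)(x, y) ≈ Matern52Kernel()(x, y)     # n = 2
```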

src/basekernels/nn.jl

Lines changed: 28 additions & 0 deletions

@@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
     return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
 end
 
+function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
+    validate_inputs(x, y)
+    X_2 = sum(x.X .* x.X; dims=1)
+    Y_2 = sum(y.X .* y.X; dims=1)
+    XY = x.X' * y.X
+    return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
+    X_2_1 = sum(x.X .* x.X; dims=1) .+ 1
+    XX = x.X' * x.X
+    return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
+    validate_inputs(x, y)
+    X_2 = sum(x.X .* x.X; dims=2)
+    Y_2 = sum(y.X .* y.X; dims=2)
+    XY = x.X * y.X'
+    return asin.(XY ./ sqrt.((X_2 .+ 1) * (Y_2 .+ 1)'))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
+    X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
+    XX = x.X * x.X'
+    return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
+end
+
 Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
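These four methods vectorize the arcsine kernel over whole input collections instead of looping over pairs, sidestepping the generic fallback that Zygote previously failed on. A sanity sketch (sizes are illustrative) of the property they should satisfy, namely agreement with pairwise evaluation:

```julia
using KernelFunctions

k = NeuralNetworkKernel()
X, Y = rand(3, 5), rand(3, 4)  # 3-dimensional inputs stored as columns

K = kernelmatrix(k, ColVecs(X), ColVecs(Y))
size(K) == (5, 4)

# Vectorized result should match elementwise evaluation of the kernel.
K ≈ [k(x, y) for x in ColVecs(X), y in ColVecs(Y)]
```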

src/zygote_adjoints.jl

Lines changed: 18 additions & 8 deletions

@@ -60,21 +60,21 @@ end
 end
 
 @adjoint function ColVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
         throw(error("In slow method"))
     end
-    return ColVecs(X), back
+    return ColVecs(X), ColVecs_pullback
 end
 
 @adjoint function RowVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
         throw(error("In slow method"))
     end
-    return RowVecs(X), back
+    return RowVecs(X), RowVecs_pullback
 end
 
 @adjoint function Base.map(t::Transform, X::ColVecs)
@@ -84,3 +84,13 @@ end
 @adjoint function Base.map(t::Transform, X::RowVecs)
     pullback(_map, t, X)
 end
+
+@adjoint function (dist::Distances.SqMahalanobis)(a, b)
+    function SqMahalanobis_pullback(Δ::Real)
+        B_Bᵀ = dist.qmat + transpose(dist.qmat)
+        a_b = a - b
+        δa = (B_Bᵀ * a_b) * Δ
+        return (qmat = (a_b * a_b') * Δ,), δa, -δa
+    end
+    return evaluate(dist, a, b), SqMahalanobis_pullback
+end
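The hand-written adjoint encodes the standard gradients of the quadratic form f(Q, a, b) = (a−b)ᵀQ(a−b): ∂f/∂Q = (a−b)(a−b)ᵀ and ∂f/∂a = (Q + Qᵀ)(a−b) = −∂f/∂b. A finite-difference spot check of those formulas (a sketch; Q, a, b are illustrative):

```julia
using Distances, FiniteDifferences, LinearAlgebra

Q = let A = rand(5, 5)
    A'A + I  # positive definite, as SqMahalanobis expects
end
a, b = rand(5), rand(5)

fdm = central_fdm(5, 1)
g_a = grad(fdm, a -> sqmahalanobis(a, b, Q), a)[1]

# Matches δa in the adjoint above (with Δ = 1):
g_a ≈ (Q + Q') * (a - b)
```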

test/basekernels/exponential.jl

Lines changed: 1 addition & 2 deletions

@@ -38,8 +38,7 @@
     @test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
     @test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
     @test KernelFunctions.iskroncompatible(k) == true
-    test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Zygote gradient given γ"
+    test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
     test_params(k, ([γ],))
     # Coherence:
     @test GammaExponentialKernel(γ=1.0)(v1, v2) ≈ SqExponentialKernel()(v1, v2)

test/basekernels/fbm.jl

Lines changed: 2 additions & 3 deletions

@@ -22,8 +22,7 @@
     @test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5
 
     @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
-    test_ADs(FBMKernel, ADs = [:ReverseDiff])
-    @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
-
+    test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
+    @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
     test_params(k, ([h],))
 end

test/basekernels/gabor.jl

Lines changed: 1 addition & 2 deletions

@@ -17,7 +17,6 @@
     @test k.ell ≈ 1.0 atol=1e-5
     @test k.p ≈ 1.0 atol=1e-5
     @test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
-    #test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Tests failing for Zygote on differentiating through ell and p"
+    test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
     # Tests also fail for ForwardDiff and ReverseDiff, but only intermittently
 end

test/basekernels/maha.jl

Lines changed: 28 additions & 2 deletions

@@ -4,14 +4,40 @@
     v1 = rand(rng, 3)
     v2 = rand(rng, 3)
 
-    P = rand(rng, 3, 3)
+    U = UpperTriangular(rand(rng, 3, 3))
+    P = Matrix(Cholesky(U, 'U', 0))
+    @assert isposdef(P)
     k = MahalanobisKernel(P=P)
 
     @test kappa(k, x) == exp(-x)
     @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
     @test kappa(ExponentialKernel(), x) == kappa(k, x)
     @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
-    # test_ADs(P -> MahalanobisKernel(P=P), P)
+
+    M1, M2 = rand(rng, 3, 2), rand(rng, 3, 2)
+    fdm = FiniteDifferences.Central(5, 1)
+
+    function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64})
+        return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
+    end
+    a = rand()
+
+    function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
+        return MahalanobisKernel(P=Array(U'*U))(v1, v2)
+    end
+
+    @test all(FiniteDifferences.j′vp(fdm, test_mahakernel, a, U, v1, v2)[1] .≈
+              UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]))
+
+    function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
+        return SqMahalanobis(Array(U'*U))(v1, v2)
+    end
+
+    @test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈
+              UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]))
+
+    # test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
     @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
 
     test_params(k, (P,))
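These tests differentiate through an unconstrained upper-triangular factor `U` and form `P = U'U`, so the metric stays positive semi-definite along any perturbation of `U`; the `to_vec` overload tells FiniteDifferences how to flatten a `SqMahalanobis` object for the `j′vp` comparison. A small sketch of the factor trick:

```julia
using LinearAlgebra

# Any upper-triangular factor yields a valid (PSD) metric:
U = UpperTriangular(rand(3, 3))
P = Matrix(U'U)
isposdef(P)  # holds almost surely for a random factor with nonzero diagonal
```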

test/basekernels/nn.jl

Lines changed: 3 additions & 4 deletions

@@ -38,11 +38,10 @@
     @test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5
 
     A5 = ones(4,4)
-    @test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
-    @test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
+    @test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3)
+    @test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3)
     @test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))
 
     @test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
-    test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Zygote uncompatible with BaseKernel"
+    test_ADs(NeuralNetworkKernel)
 end

test/zygote_adjoints.jl

Lines changed: 20 additions & 10 deletions

@@ -4,43 +4,53 @@
     x = rand(rng, 5)
     y = rand(rng, 5)
     r = rand(rng, 5)
+    Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
+    @assert isposdef(Q)
 
-    gzeucl = gradient(:Zygote, [x,y]) do xy
+    gzeucl = gradient(:Zygote, [x, y]) do xy
         evaluate(Euclidean(), xy[1], xy[2])
     end
-    gzsqeucl = gradient(:Zygote, [x,y]) do xy
+    gzsqeucl = gradient(:Zygote, [x, y]) do xy
         evaluate(SqEuclidean(), xy[1], xy[2])
     end
-    gzdotprod = gradient(:Zygote, [x,y]) do xy
+    gzdotprod = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
     end
-    gzdelta = gradient(:Zygote, [x,y]) do xy
+    gzdelta = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.Delta(), xy[1], xy[2])
     end
-    gzsinus = gradient(:Zygote, [x,y]) do xy
+    gzsinus = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
     end
+    gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy
+        evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
+    end
 
-    gfeucl = gradient(:FiniteDiff, [x,y]) do xy
+    gfeucl = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(Euclidean(), xy[1], xy[2])
     end
-    gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy
+    gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(SqEuclidean(), xy[1], xy[2])
     end
-    gfdotprod = gradient(:FiniteDiff, [x,y]) do xy
+    gfdotprod = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
     end
-    gfdelta = gradient(:FiniteDiff, [x,y]) do xy
+    gfdelta = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.Delta(), xy[1], xy[2])
     end
-    gfsinus = gradient(:FiniteDiff, [x,y]) do xy
+    gfsinus = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
     end
+    gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy
+        evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
+    end
 
     @test all(gzeucl .≈ gfeucl)
     @test all(gzsqeucl .≈ gfsqeucl)
     @test all(gzdotprod .≈ gfdotprod)
     @test all(gzdelta .≈ gfdelta)
     @test all(gzsinus .≈ gfsinus)
+    @test all(gzsqmaha .≈ gfsqmaha)
 end
