[WIP] Fix AD issues with various kernels #154

Merged Sep 8, 2020 · 24 commits
2 changes: 1 addition & 1 deletion src/basekernels/maha.jl
@@ -3,7 +3,7 @@

Mahalanobis distance-based kernel given by
```math
-κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
+κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)' * P * (x-y)
```
where the matrix P is the metric.
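As a quick sanity check of this formula (mirroring the test suite further down), the kernel value should equal `exp` of Distances.jl's `sqmahalanobis`. A minimal sketch, with an arbitrary positive-definite `P`:

```julia
using KernelFunctions, Distances, LinearAlgebra

A = rand(3, 3)
P = A * A' + 3I          # positive definite by construction
x, y = rand(3), rand(3)

k = MahalanobisKernel(P=P)
k(x, y) ≈ exp(-sqmahalanobis(x, y, P))   # κ(x,y) = exp(-(x-y)' * P * (x-y))
```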

4 changes: 3 additions & 1 deletion src/basekernels/matern.jl
@@ -5,7 +5,9 @@ The matern kernel is a Mercer kernel given by the formula:
```
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
```
-For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
+For `ν=n+1/2, n=0,1,2,...` it can be simplified, and you should instead use
+[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref) for `n=1`,
+[`Matern52Kernel`](@ref) for `n=2`, and [`SqExponentialKernel`](@ref) for `n=∞`.
"""
struct MaternKernel{Tν<:Real} <: SimpleKernel
    ν::Vector{Tν}
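For reference, the Matérn formula above can be evaluated directly with SpecialFunctions.jl. This is a standalone sketch for illustration, not the package's implementation:

```julia
using SpecialFunctions, LinearAlgebra

# κ(x,y) = 2^{1-ν}/Γ(ν) * (√(2ν)‖x-y‖)^ν * K_ν(√(2ν)‖x-y‖)
function matern(ν, x, y)
    d = norm(x - y)
    iszero(d) && return 1.0            # the kernel tends to 1 as ‖x-y‖ → 0
    t = sqrt(2ν) * d
    return 2^(1 - ν) / gamma(ν) * t^ν * besselk(ν, t)
end

matern(3/2, rand(3), rand(3))  # ν = 3/2 agrees with Matern32Kernel
```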
28 changes: 28 additions & 0 deletions src/basekernels/nn.jl
@@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
    return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
    validate_inputs(x, y)
    X_2 = sum(x.X .* x.X; dims=1)   # squared norms of the columns (1 × n)
    Y_2 = sum(y.X .* y.X; dims=1)
    XY = x.X' * y.X
    return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
    X_2_1 = sum(x.X .* x.X; dims=1) .+ 1
    XX = x.X' * x.X
    return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
    validate_inputs(x, y)
    X_2 = sum(x.X .* x.X; dims=2)   # squared norms of the rows (n × 1)
    Y_2 = sum(y.X .* y.X; dims=2)
    XY = x.X * y.X'
    return asin.(XY ./ sqrt.((X_2 .+ 1) * (Y_2 .+ 1)'))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
    X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
    XX = x.X * x.X'
    return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
end

Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
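The specialised methods above vectorise the pointwise formula `asin(xᵀy / √((1 + ‖x‖²)(1 + ‖y‖²)))` over whole batches. A hypothetical usage check: the closed-form matrix should agree with pairwise evaluation of the kernel.

```julia
using KernelFunctions

k = NeuralNetworkKernel()
X = ColVecs(rand(5, 10))   # 10 observations of dimension 5, stored as columns
Y = ColVecs(rand(5, 8))

K = kernelmatrix(k, X, Y)          # 10×8, computed without a per-pair loop
K ≈ [k(x, y) for x in X, y in Y]   # should hold up to floating-point error
```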
26 changes: 18 additions & 8 deletions src/zygote_adjoints.jl
@@ -60,21 +60,21 @@ end
end

@adjoint function ColVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
        throw(error("In slow method"))
    end
-    return ColVecs(X), back
+    return ColVecs(X), ColVecs_pullback
end

@adjoint function RowVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
        throw(error("In slow method"))
    end
-    return RowVecs(X), back
+    return RowVecs(X), RowVecs_pullback
end

@adjoint function Base.map(t::Transform, X::ColVecs)
    pullback(_map, t, X)
end

@@ -84,3 +84,13 @@
@adjoint function Base.map(t::Transform, X::RowVecs)
    pullback(_map, t, X)
end
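With these adjoints in place, Zygote should be able to differentiate functions that wrap a matrix in `ColVecs`/`RowVecs` and return the cotangent in matrix shape. A minimal sketch (the choice of `SqExponentialKernel` is arbitrary):

```julia
using KernelFunctions, Zygote

X = rand(2, 4)
g = Zygote.gradient(X -> sum(kernelmatrix(SqExponentialKernel(), ColVecs(X))), X)[1]
size(g) == size(X)   # the pullback above hands back a plain matrix cotangent
```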

@adjoint function (dist::Distances.SqMahalanobis)(a, b)
    function SqMahalanobis_pullback(Δ::Real)
        B_Bᵀ = dist.qmat + transpose(dist.qmat)
        a_b = a - b
        δa = (B_Bᵀ * a_b) * Δ
        return (qmat = (a_b * a_b') * Δ,), δa, -δa
    end
    return evaluate(dist, a, b), SqMahalanobis_pullback
end
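The pullback implements the closed-form derivatives of `r² = (a - b)' * Q * (a - b)`: `∂r²/∂a = (Q + Qᵀ)(a - b) = -∂r²/∂b` and `∂r²/∂Q = (a - b)(a - b)'`. A finite-difference spot check, as a sketch (assumes FiniteDifferences.jl's `central_fdm` and `grad`):

```julia
using Distances, FiniteDifferences, LinearAlgebra

A = rand(3, 3); Q = A * A' + I   # a positive-definite metric
a, b = rand(3), rand(3)
Δ = rand()                        # incoming cotangent

δa = Δ * (Q + Q') * (a - b)       # matches δa in the pullback
δQ = Δ * (a - b) * (a - b)'       # matches the qmat cotangent

fdm = central_fdm(5, 1)
δa ≈ FiniteDifferences.grad(fdm, a -> Δ * sqmahalanobis(a, b, Q), a)[1]
δQ ≈ FiniteDifferences.grad(fdm, Q -> Δ * sqmahalanobis(a, b, Q), Q)[1]
```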
3 changes: 1 addition & 2 deletions test/basekernels/exponential.jl
@@ -38,8 +38,7 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true
-test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
-@test_broken "Zygote gradient given γ"
+test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
test_params(k, ([γ],))
# Coherence:
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
5 changes: 2 additions & 3 deletions test/basekernels/fbm.jl
@@ -22,8 +22,7 @@
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5

@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
-test_ADs(FBMKernel, ADs = [:ReverseDiff])
-@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
+test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
+@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
test_params(k, ([h],))
end
3 changes: 1 addition & 2 deletions test/basekernels/gabor.jl
@@ -17,7 +17,6 @@
@test k.ell ≈ 1.0 atol=1e-5
@test k.p ≈ 1.0 atol=1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
-#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
-@test_broken "Tests failing for Zygote on differentiating through ell and p"
+test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
# ForwardDiff and ReverseDiff tests also fail, but only intermittently
end
30 changes: 28 additions & 2 deletions test/basekernels/maha.jl
@@ -4,14 +4,40 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)

-P = rand(rng, 3, 3)
+U = UpperTriangular(rand(rng, 3, 3))
+P = Matrix(Cholesky(U, 'U', 0))
+@assert isposdef(P)
k = MahalanobisKernel(P=P)

@test kappa(k, x) == exp(-x)
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
-# test_ADs(P -> MahalanobisKernel(P=P), P)

M1, M2 = rand(rng, 3, 2), rand(rng, 3, 2)
fdm = FiniteDifferences.Central(5, 1)

# Teach FiniteDifferences how to flatten a SqMahalanobis into a vector of
# parameters (its metric matrix `qmat`) and rebuild it, so j′vp can handle it.
function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64})
    return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
end
a = rand()

function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
    return MahalanobisKernel(P=Array(U'*U))(v1, v2)
end

@test all(FiniteDifferences.j′vp(fdm, test_mahakernel, a, U, v1, v2)[1] .≈
          UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]))

function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
    return SqMahalanobis(Array(U'*U))(v1, v2)
end

@test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈
          UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]))

# test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"

test_params(k, (P,))
7 changes: 3 additions & 4 deletions test/basekernels/nn.jl
@@ -38,11 +38,10 @@
@test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5

A5 = ones(4,4)
-@test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
-@test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
+@test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3)
+@test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3)
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))

@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
-test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
-@test_broken "Zygote uncompatible with BaseKernel"
+test_ADs(NeuralNetworkKernel)
end
30 changes: 20 additions & 10 deletions test/zygote_adjoints.jl
@@ -4,43 +4,53 @@
x = rand(rng, 5)
y = rand(rng, 5)
r = rand(rng, 5)
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
@assert isposdef(Q)

-gzeucl = gradient(:Zygote, [x,y]) do xy
+gzeucl = gradient(:Zygote, [x, y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
end
-gzsqeucl = gradient(:Zygote, [x,y]) do xy
+gzsqeucl = gradient(:Zygote, [x, y]) do xy
evaluate(SqEuclidean(), xy[1], xy[2])
end
-gzdotprod = gradient(:Zygote, [x,y]) do xy
+gzdotprod = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
end
-gzdelta = gradient(:Zygote, [x,y]) do xy
+gzdelta = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
end
-gzsinus = gradient(:Zygote, [x,y]) do xy
+gzsinus = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end

-gfeucl = gradient(:FiniteDiff, [x,y]) do xy
+gfeucl = gradient(:FiniteDiff, [x, y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
end
-gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy
+gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy
evaluate(SqEuclidean(), xy[1], xy[2])
end
-gfdotprod = gradient(:FiniteDiff, [x,y]) do xy
+gfdotprod = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
end
-gfdelta = gradient(:FiniteDiff, [x,y]) do xy
+gfdelta = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
end
-gfsinus = gradient(:FiniteDiff, [x,y]) do xy
+gfsinus = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end


@test all(gzeucl .≈ gfeucl)
@test all(gzsqeucl .≈ gfsqeucl)
@test all(gzdotprod .≈ gfdotprod)
@test all(gzdelta .≈ gfdelta)
@test all(gzsinus .≈ gfsinus)
@test all(gzsqmaha .≈ gfsqmaha)
end