Skip to content

Commit 979a019

Browse files
willtebbuttgithub-actions[bot]devmotion
authored
Implement Zygote hack (#376)
* Implement Zygote hack * Update src/KernelFunctions.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl Co-authored-by: David Widmann <[email protected]> * Further tests * Update test/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/test_utils.jl * Update test/utils.jl * Bump patch version * Restrict testing to 1.6 Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent 16f1149 commit 979a019

File tree

7 files changed

+53
-12
lines changed

7 files changed

+53
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.21"
3+
version = "0.10.22"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/KernelFunctions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
5858
using LogExpFunctions: softplus
5959
using StatsBase
6060
using TensorCore
61-
using ZygoteRules: ZygoteRules
61+
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
62+
63+
# Hack to work around Zygote type inference problems.
64+
const Distances_pairwise = Distances.pairwise
6265

6366
abstract type Kernel end
6467
abstract type SimpleKernel <: Kernel end

src/distances/pairwise.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ end
1313
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
1414

1515
function pairwise(d::PreMetric, x::AbstractVector{<:Real})
16-
return Distances.pairwise(d, reshape(x, :, 1); dims=1)
16+
return Distances_pairwise(d, reshape(x, :, 1); dims=1)
1717
end
1818

1919
function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
20-
return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
20+
return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
2121
end
2222

2323
function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})

src/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X))
8080

8181
dim(x::ColVecs) = size(x.X, 1)
8282

83-
pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2)
84-
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2)
83+
pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
84+
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
8585
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
86-
return Distances.pairwise(d, reduce(hcat, x), y.X; dims=2)
86+
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
8787
end
8888
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
89-
return Distances.pairwise(d, x.X, reduce(hcat, y); dims=2)
89+
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
9090
end
9191
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
9292
return Distances.pairwise!(out, d, x.X; dims=2)
@@ -150,13 +150,13 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X))
150150

151151
dim(x::RowVecs) = size(x.X, 2)
152152

153-
pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1)
154-
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1)
153+
pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
154+
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
155155
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
156-
return Distances.pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
156+
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
157157
end
158158
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
159-
return Distances.pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
159+
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
160160
end
161161
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
162162
return Distances.pairwise!(out, d, x.X; dims=1)

src/zygoterules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ end
55
ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs)
66
return ZygoteRules.pullback(_map, t, X)
77
end
8+
9+
function ZygoteRules._pullback(
10+
cx::AContext, ::typeof(literal_getproperty), x::ColVecs, ::Val{f}
11+
) where {f}
12+
return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}())
13+
end

test/test_utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ end
3131

3232
# AD utilities
3333

34+
# Type to work around some performance issues that can happen on the reverse-pass of Zygote.
35+
# This context doesn't allow any globals. Don't use this if you use globals in your
36+
# programme.
37+
struct NoContext <: Zygote.AContext end
38+
39+
Zygote.cache(cx::NoContext) = (cache_fields = nothing)
40+
Base.haskey(cx::NoContext, x) = false
41+
Zygote.accum_param(::NoContext, x, Δ) = Δ
42+
3443
const FDM = FiniteDifferences.central_fdm(5, 1)
3544

3645
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)
@@ -87,6 +96,13 @@ function test_ADs(
8796
end
8897
end
8998

99+
function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
100+
@inferred f(args...)
101+
@inferred Zygote._pullback(ctx, f, args...)
102+
out, pb = Zygote._pullback(ctx, f, args...)
103+
@inferred pb(out)
104+
end
105+
90106
function test_ADs(
91107
k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3)
92108
)

test/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,22 @@
5151
X_, back = Zygote.pullback(DX -> DX.X, DX)
5252
@test back(ones(size(X)))[1].X == ones(size(X))
5353
end
54+
55+
if VERSION >= v"1.6"
56+
@testset "Zygote type-inference" begin
57+
ctx = NoContext()
58+
x = ColVecs(randn(2, 4))
59+
y = ColVecs(randn(2, 3))
60+
61+
# Ensure KernelFunctions.pairwise rather than Distances.pairwise is used.
62+
check_zygote_type_stability(
63+
x -> KernelFunctions.pairwise(SqEuclidean(), x), x; ctx=ctx
64+
)
65+
check_zygote_type_stability(
66+
(x, y) -> KernelFunctions.pairwise(SqEuclidean(), x, y), x, y; ctx=ctx
67+
)
68+
end
69+
end
5470
end
5571
@testset "RowVecs" begin
5672
DX = RowVecs(X)

0 commit comments

Comments
 (0)