Skip to content

Commit 1831cc6

Browse files
ChainTransform AD performance (#466)
* Tuples rather than vectors * Testing * Bump patch * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Improve explanation of heuristic Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b5af459 commit 1831cc6

File tree

4 files changed

+100
-13
lines changed

4 files changed

+100
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.41"
3+
version = "0.10.42"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/transform/chaintransform.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ChainTransform(ts::AbstractVector{<:Transform})
2+
ChainTransform(transforms)
33
44
Transformation that applies a chain of transformations `ts` to the input.
55
@@ -19,7 +19,7 @@ julia> map(t2 ∘ t1, ColVecs(X)) == ColVecs(A * (l .* X))
1919
true
2020
```
2121
"""
22-
struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform
22+
struct ChainTransform{V} <: Transform
2323
transforms::V
2424
end
2525

@@ -28,23 +28,23 @@ end
2828
Base.length(t::ChainTransform) = length(t.transforms)
2929

3030
# Constructor to create a chain transform with an array of parameters
31-
function ChainTransform(v::AbstractVector{<:Type{<:Transform}}, θ::AbstractVector)
31+
function ChainTransform(v, θ::AbstractVector)
3232
@assert length(v) == length(θ)
3333
return ChainTransform(v.(θ))
3434
end
3535

36-
Base.:(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁])
37-
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t))
38-
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms))
36+
Base.:(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁))
37+
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t))
38+
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...))
3939

4040
(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)
4141

4242
function _map(t::ChainTransform, x::AbstractVector)
43-
return foldl((x, t) -> map(t, x), t.transforms; init=x)
43+
return foldl((x, t) -> _map(t, x), t.transforms; init=x)
4444
end
4545

4646
set!(t::ChainTransform, θ) = set!.(t.transforms, θ)
47-
duplicate(t::ChainTransform, θ) = ChainTransform(duplicate.(t.transforms, θ))
47+
duplicate(t::ChainTransform, θ) = ChainTransform(map(duplicate, t.transforms, θ))
4848

4949
Base.show(io::IO, t::ChainTransform) = printshifted(io, t, 0)
5050

test/test_utils.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,81 @@ function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3))
274274
end
275275
end
276276
end
277+
278+
function count_allocs(f, args...)
279+
stats = @timed f(args...)
280+
return Base.gc_alloc_count(stats.gcstats)
281+
end
282+
283+
"""
284+
constant_allocs_heuristic(f, args1::T, args2::T) where {T}
285+
286+
True if number of allocations associated with evaluating `f(args1...)` is equal to those
287+
required to evaluate `f(args2...)`. Runs `f` beforehand to ensure that compilation-related
288+
allocations are not included.
289+
290+
Why is this a good test? In lots of situations it will be the case that the total amount of
291+
memory allocated by a function will vary as the input sizes vary, but the total _number_
292+
of allocations ought to be constant. A common performance bug is that the number of
293+
allocations actually does scale with the size of the inputs (e.g. due to a type
294+
instability), and we would very much like to know if this is happening.
295+
296+
Typically this kind of condition is not a sufficient condition for good performance, but it
297+
is certainly a necessary condition.
298+
299+
This kind of test is very quick to conduct (just requires running `f` 4 times). It's also
300+
easier to write than simply checking that the total number of allocations used to execute
301+
a function is below some arbitrary `f`-dependent threshold.
302+
"""
303+
function constant_allocs_heuristic(f, args1::T, args2::T) where {T}
304+
305+
# Ensure that we're not counting allocations associated with compilation.
306+
f(args1...)
307+
f(args2...)
308+
309+
allocs_1 = count_allocs(f, args1...)
310+
allocs_2 = count_allocs(f, args2...)
311+
return allocs_1 == allocs_2
312+
end
313+
314+
"""
315+
ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T}
316+
317+
Assesses `constant_allocs_heuristic` for `f`, `Zygote.pullback(f, args...)` and its
318+
pullback for both of `args1` and `args2`.
319+
320+
`Δ1` and `Δ2` are passed to the pullback associated with `Zygote.pullback(f, args1...)`
321+
and `Zygote.pullback(f, args2...)` respectively. If left as `nothing`, it is assumed that
322+
the output of the primal is an acceptable cotangent to be passed to the corresponding
323+
pullback.
324+
"""
325+
function ad_constant_allocs_heuristic(
326+
f, args1::T, args2::T; Δ1=nothing, Δ2=nothing
327+
) where {T}
328+
329+
# Check that primal has constant allocations.
330+
primal_heuristic = constant_allocs_heuristic(f, args1, args2)
331+
332+
# Check that forwards-pass has constant allocations.
333+
forwards_heuristic = constant_allocs_heuristic(
334+
(args...) -> Zygote.pullback(f, args...), args1, args2
335+
)
336+
337+
# Check that pullback has constant allocations for both arguments. Run twice to remove
338+
# compilation-related allocations.
339+
340+
# First thing
341+
out1, pb1 = Zygote.pullback(f, args1...)
342+
Δ1_val = Δ1 === nothing ? out1 : Δ1
343+
pb1(Δ1_val)
344+
allocs_1 = count_allocs(pb1, Δ1_val)
345+
346+
# Second thing
347+
out2, pb2 = Zygote.pullback(f, args2...)
348+
Δ2_val = Δ2 === nothing ? out2 : Δ2
349+
pb2(Δ2_val)
350+
allocs_2 = count_allocs(pb2, Δ2 === nothing ? out2 : Δ2)
351+
352+
pullback_heuristic = allocs_1 == allocs_2
353+
return primal_heuristic, forwards_heuristic, pullback_heuristic
354+
end

test/transform/chaintransform.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
f(x) = sin.(x)
88
tf = FunctionTransform(f)
99

10-
t = ChainTransform([tp, tf])
10+
t = ChainTransform((tp, tf))
1111

1212
# Check composition constructors.
13-
@test (tf ChainTransform([tp])).transforms == [tp, tf]
14-
@test (ChainTransform([tf]) tp).transforms == [tp, tf]
13+
@test (tf ChainTransform([tp])).transforms == (tp, tf)
14+
@test (ChainTransform([tf]) tp).transforms == (tp, tf)
1515

1616
# Verify correctness.
1717
x = ColVecs(randn(rng, 2, 3))
@@ -27,5 +27,14 @@
2727
randn(rng, 4);
2828
ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote
2929
)
30-
@test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263"
30+
31+
@testset "AD performance" begin
32+
primal, forward, pb = ad_constant_allocs_heuristic((randn(5),), (randn(10),)) do x
33+
k = SEKernel() (ScaleTransform(0.1) PeriodicTransform(10.0))
34+
return kernelmatrix(k, x)
35+
end
36+
@test primal
37+
@test forward
38+
@test pb
39+
end
3140
end

0 commit comments

Comments
 (0)