ChainTransform AD performance #466
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -274,3 +274,81 @@ function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3)) | |
end | ||
end | ||
end | ||
|
||
function count_allocs(f, args...) | ||
stats = @timed f(args...) | ||
return Base.gc_alloc_count(stats.gcstats) | ||
end | ||
|
||
""" | ||
constant_allocs_heuristic(f, args1::T, args2::T) where {T} | ||
|
||
Returns `true` if the number of allocations made when evaluating `f(args1...)` equals the | |
number made when evaluating `f(args2...)`. Runs `f` once on each set of arguments | |
beforehand to ensure that compilation-related allocations are not included. | |
|
||
Why is this a good test? In lots of situations it will be the case that the total amount of | ||
memory allocated by a function will vary as the input sizes vary, but the total _number_ | ||
of allocations ought to be constant. A common performance bug is that the number of | ||
allocations actually does scale with the size of the inputs (e.g. due to a type | ||
instability), and we would very much like to know if this is happening. | ||
|
||
Typically this kind of condition is not a sufficient condition for good performance, but it | ||
is certainly a necessary condition. | ||
|
||
This kind of test is very quick to conduct: it only requires running `f` four times (two | |
warm-up calls and two measured calls). It's also easier to write than a check that the | |
total number of allocations is below some arbitrary `f`-dependent threshold. | |
""" | ||
function constant_allocs_heuristic(f, args1::T, args2::T) where {T} | ||
|
||
# Ensure that we're not counting allocations associated with compilation. | ||
f(args1...) | ||
f(args2...) | ||
|
||
allocs_1 = count_allocs(f, args1...) | ||
allocs_2 = count_allocs(f, args2...) | ||
return allocs_1 == allocs_2 | ||
end | ||
|
||
""" | ||
ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T} | ||
|
||
Assesses `constant_allocs_heuristic` for `f`, `Zygote.pullback(f, args...)` and its | ||
pullback for both of `args1` and `args2`. | ||
|
||
`Δ1` and `Δ2` are passed to the pullback associated with `Zygote.pullback(f, args1...)` | ||
and `Zygote.pullback(f, args2...)` respectively. If left as `nothing`, it is assumed that | ||
the output of the primal is an acceptable cotangent to be passed to the corresponding | ||
pullback. | ||
""" | ||
function ad_constant_allocs_heuristic( | ||
I am afraid I don't really understand why we want to check that the number of allocations is equal?
Ah. I should add more docs then. The logic is:
It's definitely not a sufficient condition for us to know that we've got good performance, but we at least know that the number of allocations is independent of the size of the output of
Does this make sense?
Yes this is much clearer thanks! If you put this in the docstring we are good to go :) |
||
f, args1::T, args2::T; Δ1=nothing, Δ2=nothing | ||
) where {T} | ||
|
||
# Check that primal has constant allocations. | ||
primal_heuristic = constant_allocs_heuristic(f, args1, args2) | ||
|
||
# Check that forwards-pass has constant allocations. | ||
forwards_heuristic = constant_allocs_heuristic( | ||
(args...) -> Zygote.pullback(f, args...), args1, args2 | ||
) | ||
|
||
# Check that pullback has constant allocations for both arguments. Run twice to remove | ||
# compilation-related allocations. | ||
|
||
# Count pullback allocations for `args1`. | |
out1, pb1 = Zygote.pullback(f, args1...) | ||
Δ1_val = Δ1 === nothing ? out1 : Δ1 | ||
pb1(Δ1_val) | ||
allocs_1 = count_allocs(pb1, Δ1_val) | ||
|
||
# Count pullback allocations for `args2`. | |
out2, pb2 = Zygote.pullback(f, args2...) | ||
Δ2_val = Δ2 === nothing ? out2 : Δ2 | ||
pb2(Δ2_val) | ||
allocs_2 = count_allocs(pb2, Δ2_val) | |
|
||
pullback_heuristic = allocs_1 == allocs_2 | ||
return primal_heuristic, forwards_heuristic, pullback_heuristic | ||
end |
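As a rough illustration of how the heuristic in the diff above might be exercised, here is a minimal sketch. The two helper functions are copied from the diff so the snippet runs on its own; `sum_of_squares` and `boxed_elements` are hypothetical example functions that are not part of this PR, and `Base.gc_alloc_count` is an internal Base function whose behaviour may differ between Julia versions.

```julia
# Helpers copied from the diff above so that this sketch is self-contained.
function count_allocs(f, args...)
    stats = @timed f(args...)
    return Base.gc_alloc_count(stats.gcstats)
end

function constant_allocs_heuristic(f, args1::T, args2::T) where {T}
    # Warm-up calls, so that compilation-related allocations are not counted.
    f(args1...)
    f(args2...)
    return count_allocs(f, args1...) == count_allocs(f, args2...)
end

# Hypothetical example functions (assumptions for illustration only).
sum_of_squares(x) = sum(abs2, x)        # a reduction with no per-element allocation
boxed_elements(x) = map(xi -> [xi], x)  # allocates one small vector per element

# Two argument tuples of the same type but different sizes.
small = (randn(10),)
large = (randn(1_000),)

@show constant_allocs_heuristic(sum_of_squares, small, large)
@show constant_allocs_heuristic(boxed_elements, small, large)
```

The second call is expected to report `false`, because `boxed_elements` allocates once per input element, so its allocation count grows with the input size. That is the kind of scaling the heuristic is meant to catch.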
Do you think we could add a constructor
`ChainTransform(transforms...) = ChainTransform(tuple(transforms...))`,
or would that be breaking with the `ChainTransform(v, \theta)` constructor?

Hmm yeah, I think that would break that.
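One way to see the clash discussed above is with a toy type. This is only a sketch: `ToyChain`, `double` and `inc` are hypothetical names, and the untyped two-argument constructor is an assumption that may not match the actual signature of `ChainTransform(v, \theta)` in the package.

```julia
# Hypothetical stand-in for a transform chain; NOT the real ChainTransform.
struct ToyChain{V}
    transforms::V
end

# Assumed two-argument constructor: build one element per parameter in `θ`.
ToyChain(v, θ) = ToyChain(map(v, θ))

# Proposed varargs convenience constructor.
ToyChain(transforms...) = ToyChain(tuple(transforms...))

double(x) = 2x
inc(x) = x + 1

# With three arguments only the varargs method applies.
three = ToyChain(double, inc, abs)
@show three.transforms == (double, inc, abs)

# With exactly two arguments, dispatch prefers the fixed-arity method over
# the varargs one, so the call never reaches the proposed constructor.
two_arg_method = which(ToyChain, Tuple{typeof(double), typeof(inc)})
three_arg_method = which(ToyChain, Tuple{typeof(double), typeof(inc), typeof(abs)})
@show two_arg_method.isva
@show three_arg_method.isva
```

Under these assumptions, a chain of exactly two transforms would be routed to the `(v, θ)` method rather than the varargs one, which is one plausible reading of why the two constructors would not coexist cleanly.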