@@ -274,3 +274,81 @@ function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3))
274
274
end
275
275
end
276
276
end
277
+
278
"""
    count_allocs(f, args...)

Evaluate `f(args...)` once under `@timed` and return the allocation count that
`Base.gc_alloc_count` derives from the recorded GC statistics.

Note that the result is a _number_ of allocations, not a number of bytes. No warm-up call
is made here, so callers who want to exclude compilation-related allocations should run
`f(args...)` themselves beforehand.
"""
function count_allocs(f, args...)
    timing = @timed f(args...)
    gc_diff = timing.gcstats
    return Base.gc_alloc_count(gc_diff)
end
282
+
"""
    constant_allocs_heuristic(f, args1::T, args2::T) where {T}

Return `true` if evaluating `f(args1...)` performs exactly as many allocations as
evaluating `f(args2...)`, and `false` otherwise.

`f` is called once on each argument tuple before anything is measured, so that
compilation-related allocations do not contaminate the counts. In total `f` is run 4
times: two warm-up calls followed by two measured calls.

# Why is this a useful test?

The total amount of memory that a function allocates will often vary with the size of its
inputs, but the total _number_ of allocations generally ought to stay constant. A common
performance bug is that the number of allocations does in fact scale with the size of the
inputs (e.g. because of a type instability), and this check is designed to expose that.

Passing this check is typically not a sufficient condition for good performance, but it
is a necessary one. It is also quick to run, and easier to write than a test asserting
that the number of allocations is below some arbitrary `f`-dependent threshold.
"""
function constant_allocs_heuristic(f, args1::T, args2::T) where {T}

    # Warm-up: make sure compilation-related allocations are not counted below.
    for warmup_args in (args1, args2)
        f(warmup_args...)
    end

    # Measured runs, in the same order as the warm-up calls.
    n_first = count_allocs(f, args1...)
    n_second = count_allocs(f, args2...)

    return n_first == n_second
end
313
+
"""
    ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T}

Assess the constant-allocations heuristic for three computations, comparing `args1`
against `args2` in each case:

1. the primal, `f(args...)`,
2. the forwards-pass, `Zygote.pullback(f, args...)`,
3. the pullback returned by `Zygote.pullback(f, args...)`, applied to a cotangent.

# Keywords
- `Δ1`: cotangent passed to the pullback obtained from `Zygote.pullback(f, args1...)`.
- `Δ2`: cotangent passed to the pullback obtained from `Zygote.pullback(f, args2...)`.

If a cotangent is left as `nothing`, the output of the corresponding primal evaluation is
used instead, i.e. it is assumed to be an acceptable cotangent for that pullback.

# Returns
A 3-tuple `(primal_heuristic, forwards_heuristic, pullback_heuristic)` of `Bool`s, each
`true` when the allocation counts for `args1` and `args2` are equal.
"""
function ad_constant_allocs_heuristic(
    f, args1::T, args2::T; Δ1=nothing, Δ2=nothing
) where {T}

    # Check that primal has constant allocations.
    primal_heuristic = constant_allocs_heuristic(f, args1, args2)

    # Check that forwards-pass has constant allocations.
    forwards_heuristic = constant_allocs_heuristic(
        (args...) -> Zygote.pullback(f, args...), args1, args2
    )

    # Check that pullback has constant allocations. `args1` is fully processed before
    # `args2`, since the left operand of `==` is evaluated first.
    pullback_heuristic =
        _pullback_alloc_count(f, args1, Δ1) == _pullback_alloc_count(f, args2, Δ2)

    return primal_heuristic, forwards_heuristic, pullback_heuristic
end

# Number of allocations made by one call of the pullback of `f` at `args`. The pullback
# is applied to `Δ`, or to the primal output when `Δ === nothing`.
function _pullback_alloc_count(f, args, Δ)
    out, pb = Zygote.pullback(f, args...)
    cotangent = Δ === nothing ? out : Δ

    # Run once before measuring to remove compilation-related allocations.
    pb(cotangent)

    return count_allocs(pb, cotangent)
end
0 commit comments