ChainTransform AD performance #466

Merged
merged 5 commits into from
Aug 22, 2022

Conversation

willtebbutt
Member

@willtebbutt willtebbutt commented Aug 19, 2022

Summary

The ChainTransform has some performance issues on master.

Evidence:

using BenchmarkTools, KernelFunctions, Zygote

kernel(θ) = with_lengthscale(Matern12Kernel(), 0.5) ∘ PeriodicTransform(θ)

foo(x) = KernelFunctions._map(PeriodicTransform(1 / 5), x)
bar(θ, x) = kernelmatrix(kernel(θ), x)

const x = randn(500);
out, pb = Zygote.pullback(bar, 5.0, x);

Δ = copy(out);
@benchmark $pb($Δ)

master:

BenchmarkTools.Trial: 36 samples with 1 evaluation.
 Range (min … max):  118.074 ms … 234.881 ms  ┊ GC (min … max): 18.20% … 24.77%
 Time  (median):     140.383 ms               ┊ GC (median):    18.11%
 Time  (mean ± σ):   140.637 ms ±  21.592 ms  ┊ GC (mean ± σ):  20.45% ±  4.18%

  █ █▃█ ▃     ██  ▃ ▃
  █▇███▇█▇▇▁▇▁██▇▇█▇█▇▁▇▇▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  118 ms           Histogram: frequency by time          235 ms <

 Memory estimate: 128.37 MiB, allocs estimate: 1586367.

This branch:

BenchmarkTools.Trial: 1946 samples with 1 evaluation.
 Range (min … max):  1.979 ms … 10.005 ms  ┊ GC (min … max): 0.00% … 77.81%
 Time  (median):     2.236 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.562 ms ±  1.175 ms  ┊ GC (mean ± σ):  8.04% ± 12.87%

  ██▇▆▆▅▃▃▁                                                  ▁
  ███████████▇▄▆▄▅▄▁▁▁▇▇▄▄▄▄▁▄▄▁▁▁▁▄▄▄▁▄▄▅▄▆▄▄▁▄▅▄▆▇▄▅▅▅▅▄▄▄ █
  1.98 ms      Histogram: log(frequency) by time     8.98 ms <

 Memory estimate: 7.70 MiB, allocs estimate: 272.

Proposed changes

  1. Use a tuple rather than a vector to hold the transforms being chained together. This enables type-stable composition.
  2. Call _map rather than map, because _map is the package's internal API for applying a transform to a vector of inputs.
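
To illustrate the first change, here is a minimal plain-Julia sketch of why a tuple helps. `Scale`, `Shift`, and `apply_chain` are hypothetical stand-ins invented for this example, not KernelFunctions types or functions. The point is only that a `Tuple` records the concrete type of each element, whereas a `Vector` holding elements of different types falls back to an abstract element type.

```julia
# Hypothetical stand-ins for two different transform types (not KernelFunctions API).
struct Scale
    s::Float64
end
(t::Scale)(x) = t.s * x

struct Shift
    b::Float64
end
(t::Shift)(x) = x + t.b

# Tuple version: recurse over the tuple so that each call is dispatched on
# the concrete type of the element being applied.
apply_chain(::Tuple{}, x) = x
apply_chain(ts::Tuple, x) = apply_chain(Base.tail(ts), first(ts)(x))

# Vector version: elements are only known as `eltype(ts)` at compile time.
function apply_chain(ts::AbstractVector, x)
    for t in ts
        x = t(x)
    end
    return x
end

tuple_chain = (Scale(2.0), Shift(1.0))
vector_chain = [Scale(2.0), Shift(1.0)]

# The tuple type lists every element type; the vector's element type is abstract.
@show typeof(tuple_chain)
@show eltype(vector_chain)
@show isconcretetype(eltype(vector_chain))

# Both produce the same value; only the inferred types differ.
@show apply_chain(tuple_chain, 3.0)
@show apply_chain(vector_chain, 3.0)
```

Running `@code_warntype apply_chain(vector_chain, 3.0)` in a REPL should flag the loop variable as non-concrete, while the tuple version should infer cleanly.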

Note that I'm testing that this change is successful by checking that the number of allocations required to compute the kernelmatrix, its forwards pass, and its pullback (using Zygote) is invariant to the size of the input vector. I plan to roll this out more widely in the coming days.

What alternatives have you considered?

None

Breaking changes

This only widens the set of permissible types in the ChainTransform, and changes which one gets used by default. On that basis, my inclination is that we shouldn't consider this breaking, but I might have missed something obvious.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@codecov

codecov bot commented Aug 19, 2022

Codecov Report

Merging #466 (aef58b9) into master (b5af459) will decrease coverage by 0.07%.
The diff coverage is 66.66%.

@@            Coverage Diff             @@
##           master     #466      +/-   ##
==========================================
- Coverage   93.16%   93.09%   -0.08%     
==========================================
  Files          52       52              
  Lines        1259     1275      +16     
==========================================
+ Hits         1173     1187      +14     
- Misses         86       88       +2     
Impacted Files                          Coverage Δ
src/transform/chaintransform.jl         80.00% <66.66%> (+1.73%) ⬆️
src/matrix/kernelpdmat.jl               75.00% <0.00%> (-6.82%) ⬇️
src/kernels/normalizedkernel.jl         80.00% <0.00%> (-2.36%) ⬇️
src/mokernels/lmm.jl                    100.00% <0.00%> (ø)
src/kernels/kernelsum.jl                100.00% <0.00%> (ø)
src/kernels/kernelproduct.jl            100.00% <0.00%> (ø)
src/kernels/kerneltensorproduct.jl      98.85% <0.00%> (+0.08%) ⬆️
src/approximations/nystrom.jl           92.68% <0.00%> (+0.18%) ⬆️

@willtebbutt willtebbutt requested a review from theogf August 19, 2022 14:40
Member

@theogf theogf left a comment

Looks good. I also don't think it's breaking as it is

@@ -28,23 +28,23 @@ end
Base.length(t::ChainTransform) = length(t.transforms)

Member

Do you think we could add a constructor ChainTransform(transforms...) = ChainTransform(tuple(transforms...)), or would that break the ChainTransform(v, \theta) constructor?

Member Author

Hmm, yeah, I think that would break that constructor.

the output of the primal is an acceptable cotangent to be passed to the corresponding
pullback.
"""
function ad_constant_allocs_heuristic(
Member

I am afraid I don't really understand why we want to check that the number of allocations is equal?

Member Author

Ah. I should add more docs then. The logic is:

  1. If we've implemented this well, the number of allocations shouldn't change with input size (although the total amount of memory allocated obviously will).
  2. A really common cause of bad performance is a type instability that introduces at least one allocation per kernelmatrix element.
  3. We can measure the number of allocations really quickly.

It's definitely not a sufficient condition for us to know that we've got good performance, but we at least know that the number of allocations is independent of the size of the output of kernelmatrix, which rules out a decent number of performance bugs.
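
The idea can be sketched in plain Julia without any AD. This is a simplified illustration, not the implementation of `ad_constant_allocs_heuristic`: `count_allocs`, `stable_sum`, and `unstable_sum` are hypothetical names for this example, and it assumes a Julia version in which `@timed` returns a `gcstats` field (1.5 or later).

```julia
# Count how many allocations a call performs (assumes `@timed` exposes `gcstats`).
function count_allocs(f, args...)
    stats = @timed f(args...)
    return Base.gc_alloc_count(stats.gcstats)
end

# Type-stable: the accumulator is always a Float64.
function stable_sum(xs::Vector{Float64})
    s = 0.0
    for x in xs
        s += x
    end
    return s
end

# Type-unstable: elements are `Any`, so each addition is dispatched at run time
# and its result is boxed.
function unstable_sum(xs::Vector{Any})
    s = 0.0
    for x in xs
        s += x
    end
    return s
end

small, large = randn(1_000), randn(2_000)
small_any, large_any = Vector{Any}(small), Vector{Any}(large)

# Warm up so that compilation is not included in the counts.
count_allocs(stable_sum, small); count_allocs(unstable_sum, small_any)

stable_counts = (count_allocs(stable_sum, small), count_allocs(stable_sum, large))
unstable_counts = (count_allocs(unstable_sum, small_any), count_allocs(unstable_sum, large_any))

@show stable_counts
@show unstable_counts
```

The expectation is that `stable_counts` holds two equal numbers while `unstable_counts` grows with the input length. The heuristic in this PR applies the same kind of comparison to kernelmatrix, its forwards pass, and its pullback.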

Does this make sense?

Member

Yes, this is much clearer, thanks! If you put this in the docstring, we are good to go :)

@willtebbutt
Member Author

@theogf let me know whether my explanation of the testing is sufficient, and I'll add a docstring + merge

@willtebbutt
Member Author

Will squash + merge when CI passes

@willtebbutt willtebbutt merged commit 1831cc6 into master Aug 22, 2022
@willtebbutt willtebbutt deleted the wct/transform-perfomrance branch August 22, 2022 12:11
@willtebbutt willtebbutt mentioned this pull request Aug 22, 2022