Make pullback error for ColVecs and RowVecs a bit more informative #523


Conversation

Contributor

@torfjelde torfjelde commented Sep 12, 2023

The adjoint defined for ColVecs and RowVecs indicates that the cause of the error might be something internal to KernelFunctions.jl. But other packages are also making use of ColVecs and RowVecs, e.g. AbstractGPs.jl, which in turn means that issues often encountered in practice are not related to what the adjoint-error indicates, e.g. JuliaGaussianProcesses/AbstractGPs.jl#344

I ran into this yesterday and because of the error message I spent a fair bit of time looking for bugs in the kernel-related code rather than looking elsewhere. Hence the PR:)

My wording in the error message can probably do with an improvement?

EDIT: Oh and why does this adjoint definition exist? Why not just pull back the vector of vectors to a matrix? AFAIK the adjoint is still valid?

@codecov

codecov bot commented Sep 12, 2023

Codecov Report

All modified lines are covered by tests ✅

Files Coverage Δ
src/chainrules.jl 45.67% <ø> (-23.46%) ⬇️

... and 22 files with indirect coverage changes


@devmotion
Member

> EDIT: Oh and why does this adjoint definition exist? Why not just pull back the vector of vectors to a matrix? AFAIK the adjoint is still valid?

To avoid surprisingly slow code. ColVecs and RowVecs wrap a concatenated matrix of the vectors and usually it is faster to work with that underlying matrix instead of constructing a vector of vectors.
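
A minimal sketch of the difference between the two paths, assuming a hypothetical stand-in type `ToyColVecs` (this is not the KernelFunctions.jl definition, just the same idea in a few lines): the generic `AbstractVector` path allocates one small vector per column, while the matrix path is a single reduction over the wrapped matrix.

```julia
# Hypothetical stand-in for ColVecs (not the KernelFunctions.jl source):
# a vector-of-columns view over one contiguous matrix.
struct ToyColVecs{T,TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
    X::TX
    ToyColVecs(X::TX) where {T,TX<:AbstractMatrix{T}} = new{T,TX}(X)
end

# Minimal AbstractVector interface: one element per column of X.
Base.size(v::ToyColVecs) = (size(v.X, 2),)
Base.getindex(v::ToyColVecs, i::Int) = v.X[:, i]

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
xs = ToyColVecs(X)

# Generic AbstractVector path: materialises one small vector per column.
slow = map(sum, xs)

# Matrix path: a single reduction over the wrapped matrix.
fast = vec(sum(xs.X; dims=1))
```

Both paths give the same numbers here; the point is only that code specialised on the wrapper can go straight to `xs.X`.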

The code was introduced in #84 initially (moved to ChainRules in #208) but as it was copied from Stheno.jl maybe @willtebbutt has some additional comments.

@st--
Member

st-- commented Sep 12, 2023

Looks good to me, can you also bump the patch version?

@torfjelde
Contributor Author

> To avoid surprisingly slow code.

That's what I suspected, but is it worth it given how difficult it can be to debug these adjoint issues for most users? 😬

@devmotion
Member

Difficult to debug implies that probably it would be even more difficult to notice the problem and the cause for the slow performance without the error 😛

@torfjelde
Contributor Author

True! But KernelFunctions.jl generally supports these slow-paths given that it's all defined on AbstractVector, no? And given that user-defined methods can easily touch this, e.g. through AbstractGPs.CustomMean, then this error breaks the entire call rather than just having the user-defined method be slow.

@torfjelde
Contributor Author

I think the error makes complete sense in the scenario where all usages of ColVecs and RowVecs are internal, but that it becomes somewhat more nuanced when it can interact with external code.

@torfjelde
Contributor Author

torfjelde commented Sep 12, 2023

For example, the issue referenced above is caused by us silently taking the "slow path", and we take it because there exists a default implementation for AbstractGPs.mean_vector. So I guess my question boils down to: why allow this slow default method to be hit but not the slow pullback?

EDIT: I realize AbstractGPs is a different package, but since it's under the same org + probably done by the same people, I'm guessing the decisions are somewhat related:)

@devmotion
Member

Without looking into all details, if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs. But of course we want to be as generic as possible so we define methods for AbstractVector - we just always want to take the optimal path for ColVecs/RowVecs.

@willtebbutt
Member

> Without looking into all details, if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs. But of course we want to be as generic as possible so we define methods for AbstractVector - we just always want to take the optimal path for ColVecs/RowVecs.

I agree with this. I can definitely see your point @torfjelde, in that if a slow path is taken, it might be nice for the code not to fall over. Moreover, if the forwards-pass takes a slow path, then I am completely fine with the reverse-pass also taking a slow path. What I was primarily trying to guard against was the forwards-pass taking the fast path, and the pullback somehow hitting the slow path (I believe that this bit me several times when I hit a piratic rrule).

I'll add another comment to the PR, as I think we could clarify this even further.

@torfjelde
Contributor Author

torfjelde commented Sep 12, 2023

> if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs

Ah gotcha, then it makes sense 👍

> What I was primarily trying to guard against was the forwards-pass taking the fast path, and the pullback somehow hitting the slow path (I believe that this bit me several times when I hit a piratic rrule).

Yeah I can see how an incorrect rrule could cause issues, so it def seems sensible 👍 To me the confusion was mainly when a forward-pass hits the slow path, then we should be happy with the reverse also taking the slow path.

But so this is in general a "restriction" with ColVecs and RowVecs then? You're not allowed to "automatically" AD through "slow paths", such as the example from AbstractGPs?

EDIT: As in, without either defining an explicit overload for the forward-pass or defining a custom adjoint.

@willtebbutt
Member

> But so this is in general a "restriction" with ColVecs and RowVecs then? You're not allowed to "automatically" AD through "slow paths", such as the example from AbstractGPs?

Depends on the kind of slow path. If when you AD the slow path, you produce a Vector{Vector{T}} cotangent for a ColVecs/RowVecs, then no. If it's a slow path, but you get the right type, it'll work fine.

@torfjelde
Contributor Author

torfjelde commented Sep 13, 2023

> If when you AD the slow path, you produce a Vector{Vector{T}} cotangent for a ColVecs/RowVecs, then no. If it's a slow path, but you get the right type, it'll work fine.

With "slow path" I mean when a method that works with AbstractVector or similar receives a ColVecs, e.g. map. In these scenarios, the evaluation/forward pass is allowed and works just fine, even though you're hitting a slow path, but (specifically reverse-mode) AD is not allowed. Have I understood correctly?

@willtebbutt
Member

> With "slow path" I mean when a method that works with AbstractVector or similar receives a ColVecs, e.g. map. In these scenarios, the evaluation/forward pass is allowed and works just fine, even though you're hitting a slow path, but (specifically reverse-mode) AD is not allowed. Have I understood correctly?

Kind of. I'm just saying it's more specific than that: AD is only disallowed when the reverse-pass returns a Vector{Vector{T}} rather than a Tangent{ColVecs{T}} or whatever. It happens to be the case that this tends to align with when the forwards-pass does something slow and annoying, but you could imagine situations in which this isn't the case. e.g. map(sum, ColVecs(randn(10, 20))) isn't going to be particularly slow. My point is that it's not so much to do with slowness, as it is type errors.
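
To make the type point concrete, here is a hand-written sketch of the two cotangent shapes for `y = map(sum, xs)`, where `xs` wraps a D×N matrix. No AD package is involved, and the `NamedTuple` is only an assumed stand-in for a structural tangent such as `Tangent{ColVecs{T}}`; the names `ybar`, `vec_of_vecs` and `structural` are hypothetical.

```julia
# Hand-written cotangents for y = map(sum, xs), with xs wrapping a D×N matrix X.
# The NamedTuple below merely stands in for a structural tangent (assumption).
D, N = 2, 3
ybar = [1.0, 10.0, 100.0]                      # cotangent of the length-N output

# What a generic AbstractVector rule would hand back: one vector per element.
vec_of_vecs = [fill(ybar[n], D) for n in 1:N]  # Vector{Vector{Float64}}

# What the wrapper wants: a cotangent for its single field, the matrix X.
structural = (X = ones(D) * ybar',)            # NamedTuple holding a D×N matrix

# Same numbers, different types.
same_numbers = reduce(hcat, vec_of_vecs) == structural.X
```

The two objects carry identical entries; only the container type differs, and that type is what the pullback for the constructor dispatches on.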

@torfjelde
Contributor Author

> It happens to be the case that this tends to align with when the forwards-pass does something slow and annoying, but you could imagine situations in which this isn't the case. e.g. map(sum, ColVecs(randn(10, 20))) isn't going to be particularly slow. My point is that it's not so much to do with slowness, as it is type errors.

Sure, that I'm with; I was referring to "slow paths" because I thought that was the original motivation as to why this rrule was implemented as is.

But is it understandable that it's somewhat confusing that something like map(sum, ColVecs(randn(10, 20))) is allowed but computing the pullback of this is not?

@willtebbutt
Member

willtebbutt commented Sep 14, 2023

> But is it understandable that it's somewhat confusing that something like map(sum, ColVecs(randn(10, 20))) is allowed but computing the pullback of this is not?

It's entirely understandable, but I felt (still feel) that it's the lesser of two evils -- it's opting for a loud error when performance is bad (read: catastrophic) on the reverse-pass, rather than allowing the code to run very slowly, but eventually produce an answer.

@torfjelde
Contributor Author

Gotcha, gotcha 👍

Thanks for explaining the decision! I'm curious about this stuff partially because of the discussions we've had regarding more general batching structures in the past.

Member

@willtebbutt willtebbutt left a comment


Thanks for sorting this out, @torfjelde. I'll merge when CI passes.

@willtebbutt
Member

Something has broken with AD 🤦 . It's clearly not this PR's fault, so I'm going to merge and tag a release anyway.

@willtebbutt willtebbutt merged commit cf937ce into JuliaGaussianProcesses:master Sep 27, 2023
4 participants