[AutoDiff] Handle materializing adjoints with non-differentiable fields #67319

Merged (2 commits) on Sep 12, 2023

jkshtj (Contributor) commented Jul 14, 2023

This PR fixes a compiler crasher in AutoDiff.

The compiler used to crash while generating a pullback for differentiable functions that take a differentiable input, whose tangent vector contains non-differentiable fields.

Changes in this PR fix the issue by specializing adjoint materialization logic for user-defined tangent vectors containing non-differentiable fields.

Fixes #66522

jkshtj (Contributor Author) commented Jul 14, 2023

@asl @BradLarson could one of you invoke swift-ci tests?

asl (Contributor) commented Jul 14, 2023

I am confused. I thought the discussion was that we need to fix adjoint generation for the struct_extract instruction. Are there other use cases besides struct_extract that should be covered?

What is the test coverage for newly added code?

asl (Contributor) commented Jul 14, 2023

@swift-ci please test

jkshtj (Contributor Author) commented Jul 14, 2023

> I am confused. I thought the discussion was that we need to fix adjoint generation for the struct_extract instruction.

I think I may have misunderstood how to go about fixing the issue. Are you suggesting that we never add non-differentiable fields to the adjoint for struct_extract?

jkshtj (Contributor Author) commented Jul 14, 2023

> Are there other use cases besides struct_extract that should be covered?

If we do make the fix in the adjoint generation code, then I believe we'd have to at least do the same for tuple_extract. There may be other instructions as well, but I'd need to take a closer look.

jkshtj (Contributor Author) commented Jul 14, 2023

> What is the test coverage for newly added code?

In unit tests, none.

In the end-to-end tests, I added a test for the path where we need to materialize an adjoint with non-differentiable fields. I should perhaps also have added a test for materializing normal adjoints with all differentiable fields.

asl (Contributor) commented Jul 14, 2023

> I think I may have misunderstood how to go about fixing the issue. Are you suggesting that we never add non-differentiable fields to the adjoint for struct_extract?

There are multiple questions that need to be answered:

  1. Are there any other instructions that are affected here besides struct_extract? For example, you've added handling of adjoint buffers in addition to values. Are there any test cases for this code?

  2. Next, the current adjoint generation for struct_extract is essentially:

  ///   Original: y = struct_extract x, #field
  ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
  ///                                       ^~~~~~~
  ///                     field in tangent space corresponding to #field

So, instead of materializing that bunch of zeros, why can't we simply do the following (in struct_extract adjoint generation):

adj[x].field += adj[y]

After all, we know that things should conform to AdditiveArithmetic and therefore a += 0 should be equal to a.
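
The two formulations in this comment can be compared with a minimal C++ sketch. It is a toy model, not compiler code: `Tangent`, `addInPlace`, `accumulateViaZeroStruct`, and `accumulateIntoField` are hypothetical names, and plain `double` fields stand in for tangent-space fields.

```cpp
// Hypothetical tangent type with two differentiable fields.
struct Tangent {
    double a;
    double b;
};

// Whole-struct addition, standing in for AdditiveArithmetic's +=.
void addInPlace(Tangent &lhs, const Tangent &rhs) {
    lhs.a += rhs.a;
    lhs.b += rhs.b;
}

// Formulation 1: adj[x] += struct (0, ..., adj[y], ..., 0).
// A full struct is built with a zero in every field except the extracted one.
void accumulateViaZeroStruct(Tangent &adjX, double adjY) {
    Tangent materialized{0.0, 0.0};  // requires a zero for every field
    materialized.b = adjY;           // only the extracted field is non-zero
    addInPlace(adjX, materialized);
}

// Formulation 2: adj[x].field += adj[y].
// Only the extracted field is touched; no zeros are materialized.
void accumulateIntoField(Tangent &adjX, double adjY) {
    adjX.b += adjY;
}
```

Both functions leave `adjX` in the same state, since adding zero to the untouched fields changes nothing. Only the first one needs a zero value for every field, which is presumably where a field without one (such as a non-differentiable field) becomes a problem.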

asl (Contributor) commented Jul 14, 2023

> In unit tests, none.

Please ensure that all codepaths are covered by tests.

jkshtj (Contributor Author) commented Jul 15, 2023

> 1. Are there any other instructions that are affected here besides `struct_extract`? For example, you've added handling of adjoint buffers in addition to values. Are there any test cases for this code?

Since the adjoint materialization code is shared by all instructions, we shouldn't need to handle things separately for different instructions. As for the adjoint buffer handling, it looks like that code isn't actually called anywhere.

jkshtj (Contributor Author) commented Jul 15, 2023

> 2. Next, the current adjoint generation for `struct_extract` is essentially:
>   ///   Original: y = struct_extract x, #field
>   ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
>   ///                                       ^~~~~~~
>   ///                     field in tangent space corresponding to #field
>
> So, instead of materializing that bunch of zeros, why can't we simply do the following:
>
> adj[x].field += adj[y]
>
> After all, we know that things should conform to AdditiveArithmetic and therefore a += 0 should be equal to a.

Ah, indeed this is what I tried to do first, and I ran into an issue with TangentBuilder's emitInPlaceAdd method, which takes the lhs as a buffer. This was causing SIL verification failures when '+='-ing trivial types such as Double, with the following error message:

SIL verification failed: operand of 'apply' doesn't match function input type
  $Double
  $*Double
Verifying instruction:
   %0 = argument of bb0 : $Double                 // user: %8
     %5 = struct_element_addr %1 : $*P<Double>, #P.value // user: %8
     %6 = witness_method $Double, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %8
     %7 = metatype $@thick Double.Type            // user: %8
->   %8 = apply %6<Double>(%5, %0, %7) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()

It looks like I may have misinterpreted the meaning of the error, however. And I think I simply needed to change the type of the input arguments.

I can definitely try this out.

jkshtj (Contributor Author) commented Jul 15, 2023

> > In unit tests, none.
>
> Please ensure that all codepaths are covered by tests.

Would you know of an existing unit test that tests autodiff code? I wasn't quite sure how to set up PullbackCloner for unit testing.

asl (Contributor) commented Jul 15, 2023

> > 1. Are there any other instructions that are affected here besides `struct_extract`? For example, you've added handling of adjoint buffers in addition to values. Are there any test cases for this code?
>
> Since the adjoint materialization code is shared by all instructions, we shouldn't need to handle things separately for different instructions. As for the adjoint buffer handling, it looks like that code isn't actually called anywhere.

All right, but what are those "different instructions", besides tuple_extract? It looks like you're adding a whole bunch of generic code without any tests just to fix one particular corner case. Differential paths are known to have less-than-optimal test coverage, and this usually causes all kinds of issues when the code is actually called in some obscure case. So we try to improve the test coverage, not reduce it.

asl (Contributor) commented Jul 15, 2023

> Would you know of an existing unit test that tests autodiff code? I wasn't quite sure how to set up PullbackCloner for unit testing.

test/AutoDiff contains the tests for different aspects of autodiff code.

asl (Contributor) commented Jul 15, 2023

> This was causing SIL verification failures when '+='-ing trivial types such as Double, with the following error message:

The lhs is a struct somewhere, so you should be able to take the address of its field and pass it to the += method. However, in your particular case it was the rhs that caused the assertion, as that argument is passed indirectly.
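
The mismatch in the verifier output (`$Double` passed where `$*Double` is expected) can be illustrated with a rough C++ analogy. This is a sketch under assumptions, not how the Swift compiler is implemented: `PlusEqualsWitness`, `doublePlusEquals`, `P`, and `accumulateField` are hypothetical names, and a type-erased function pointer stands in for the `AdditiveArithmetic."+="` witness, whose SIL signature above takes `@inout` and `@in_guaranteed` parameters.

```cpp
// Hypothetical "witness" for +=: a type-erased entry point that a generic
// caller can use without knowing the concrete type. Both operands arrive
// by address, loosely mirroring (@inout Self, @in_guaranteed Self).
using PlusEqualsWitness = void (*)(void *lhs, const void *rhs);

// The witness implementation for double.
void doublePlusEquals(void *lhs, const void *rhs) {
    *static_cast<double *>(lhs) += *static_cast<const double *>(rhs);
}

// Stand-in for the struct P<Double> that appears in the verifier output.
struct P {
    double value;
};

// Accumulate a loaded value into a field through the witness.
void accumulateField(P &base, double loadedRhs, PlusEqualsWitness witness) {
    // The lhs is the address of a field (compare struct_element_addr).
    // The rhs is a loaded value, so it is first stored to memory and its
    // address is passed; passing the value itself would not match the
    // witness signature.
    double rhsSlot = loadedRhs;
    witness(&base.value, &rhsSlot);
}
```

For example, `P p{1.5}; accumulateField(p, 2.0, doublePlusEquals);` leaves `p.value` at 3.5.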

jkshtj (Contributor Author) commented Jul 17, 2023

> > Would you know of an existing unit test that tests autodiff code? I wasn't quite sure how to set up PullbackCloner for unit testing.
>
> test/AutoDiff contains the tests for different aspects of autodiff code.

Oh, I think I was mistaking unit tests for the ones that occur in the unittests directory in the Swift repo.

jkshtj (Contributor Author) commented Jul 17, 2023

> > This was causing SIL verification failures when '+='-ing trivial types such as Double, with the following error message:
>
> The lhs is a struct somewhere, so you should be able to take the address of its field and pass it to the += method. However, in your particular case it was the rhs that caused the assertion, as that argument is passed indirectly.

Yeah, it was the right-hand one. I was a little surprised because I thought only the left-hand operand needed to be taken as a pointer (or by reference). But I'm guessing this more generalized SIL definition exists to accommodate the different types that may want to conform to AdditiveArithmetic?

jkshtj (Contributor Author) commented Jul 17, 2023

> > > 1. Are there any other instructions that are affected here besides `struct_extract`? For example, you've added handling of adjoint buffers in addition to values. Are there any test cases for this code?
> >
> > Since the adjoint materialization code is shared by all instructions, we shouldn't need to handle things separately for different instructions. As for the adjoint buffer handling, it looks like that code isn't actually called anywhere.
>
> All right, but what are those "different instructions", besides tuple_extract? It looks like you're adding a whole bunch of generic code without any tests just to fix one particular corner case. Differential paths are known to have less-than-optimal test coverage, and this usually causes all kinds of issues when the code is actually called in some obscure case. So we try to improve the test coverage, not reduce it.

I see the point. Indeed, I missed a number of tests that could be added with this change. I was thinking we can start by covering pullbacks for instructions listed here.

Can you think of any others that I should also test?

jkshtj (Contributor Author) commented Jul 17, 2023

> 2. Next, the current adjoint generation for `struct_extract` is essentially:
>   ///   Original: y = struct_extract x, #field
>   ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
>   ///                                       ^~~~~~~
>   ///                     field in tangent space corresponding to #field
>
> So, instead of materializing that bunch of zeros, why can't we simply do the following:
>
> adj[x].field += adj[y]
>
> After all, we know that things should conform to AdditiveArithmetic and therefore a += 0 should be equal to a.

@asl I've incorporated this feedback about adjoint materialization and am creating a new PR revision for some early feedback.

I'm still working on adding test coverage but am struggling a bit to understand what exactly the tests should be doing. I'll post my questions about tests in another comment shortly.

jkshtj (Contributor Author) commented Jul 18, 2023

@rxwei Could you also please take a look at the approach and comment whether it makes sense and whether I am making changes at the right place?

asl (Contributor) commented Jul 18, 2023

> @asl I've incorporated this feedback about adjoint materialization and am creating a new PR revision for some early feedback.

Have you forgotten to push your new changes?

jkshtj (Contributor Author) commented Jul 18, 2023

> > @asl I've incorporated this feedback about adjoint materialization and am creating a new PR revision for some early feedback.
>
> Have you forgotten to push your new changes?

I don't think so. I amended the existing commit, however, instead of creating a new one.

jkshtj requested a review from asl on July 19, 2023 16:11

jkshtj (Contributor Author) commented Sep 1, 2023

@rxwei @asl Thanks a lot for reviewing this PR so far!

I have fixed the nits. Could you please take a look?

asl (Contributor) commented Sep 1, 2023

> @rxwei @asl Thanks a lot for reviewing this PR so far!
>
> I have fixed the nits. Could you please take a look?

See new comment. Also, will you please squash the commits?

jkshtj (Contributor Author) commented Sep 1, 2023

> See new comment.

Did you forget to publish the comment 😅?

> Also, will you please squash the commits?

Sure.

@@ -430,6 +430,8 @@ class PullbackCloner::Implementation final
}
case AdjointValueKind::AddElement: {
auto baseAdjoint = val;
assert(baseAdjoint.getType().is<TupleType>() ||
Contributor:

This is not what should be checked here. This assertion is enforced during AddElement construction. Here you need to ensure that all nested adjoints are of the same kind: either all tuples or all structs, not a mixture.

Contributor Author:

I see. Sorry I sent out that last revision in haste. Fixed in the latest revision and squashed commits.

addEltAdjValues.push_back(addElementValue);
baseAdjoint = addElementValue->baseAdjoint;
assert(baseAdjointType == baseAdjoint.getType());
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
Contributor:

Do we have tests for this loop, btw? Maybe I'm missing something, but all loops are doing single struct/tuple extract?

Contributor Author:

Not explicit tests but the other tests I have added do use this code path -- I needed to fix new test failures after making these changes. But I think it's a good idea to add explicit tests for this.

Contributor Author:

@asl I have added tests exercising the nested add element adjoint materialization path.
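
The loop discussed in this thread walks a chain of nested AddElement adjoints down to a base adjoint. A minimal C++ sketch of that idea follows. It is a toy model, not the `PullbackCloner` code: `ChainAdjoint`, `makeConcrete`, `makeAddElement`, and `materializeChain` are hypothetical names, fields are plain `double`s, and the `Shape` tag stands in for the "all structs or all tuples" check requested above.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Toy lazy adjoint: either concrete field values, or "base adjoint plus one
// extra contribution to a single field" (the AddElement case).
struct ChainAdjoint {
    enum class Kind { Concrete, AddElement };
    enum class Shape { Struct, Tuple };
    Kind kind = Kind::Concrete;
    Shape shape = Shape::Struct;
    std::vector<double> fields;                 // used when Concrete
    std::shared_ptr<const ChainAdjoint> base;   // used when AddElement
    std::size_t fieldIndex = 0;                 // used when AddElement
    double element = 0.0;                       // used when AddElement
};
using ChainPtr = std::shared_ptr<const ChainAdjoint>;

ChainPtr makeConcrete(ChainAdjoint::Shape shape, std::vector<double> fields) {
    auto adj = std::make_shared<ChainAdjoint>();
    adj->shape = shape;
    adj->fields = std::move(fields);
    return adj;
}

ChainPtr makeAddElement(ChainPtr base, std::size_t index, double element) {
    auto adj = std::make_shared<ChainAdjoint>();
    adj->kind = ChainAdjoint::Kind::AddElement;
    adj->shape = base->shape;
    adj->base = std::move(base);
    adj->fieldIndex = index;
    adj->element = element;
    return adj;
}

// Flatten a chain of AddElement nodes into one concrete value.
std::vector<double> materializeChain(const ChainPtr &adjoint) {
    // Walk down the chain, remembering every AddElement node on the way.
    std::vector<const ChainAdjoint *> pending;
    const ChainAdjoint *current = adjoint.get();
    while (current->kind == ChainAdjoint::Kind::AddElement) {
        pending.push_back(current);
        current = current->base.get();
        // Every nested adjoint must have the same shape as the outermost one.
        assert(current->shape == adjoint->shape);
    }
    // `current` is the concrete base: copy it once, then apply each
    // recorded contribution to the copy.
    std::vector<double> result = current->fields;
    for (const ChainAdjoint *node : pending) {
        assert(node->fieldIndex < result.size());
        result[node->fieldIndex] += node->element;
    }
    return result;
}
```

A chain with more than one AddElement node is what exercises the loop body repeatedly, which is the case the review comment asks to have covered by tests.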

asl (Contributor) left a comment

Thanks!

asl (Contributor) commented Sep 7, 2023

@swift-ci please test

jkshtj force-pushed the main branch 3 times, most recently from 0cf3dbb to 9343748, on September 10, 2023 05:16

jkshtj (Contributor Author) commented Sep 10, 2023

@asl Had to make some changes after I discovered some small bugs while compiling and running an internal test suite. I have added tests for these cases as well.

jkshtj requested a review from asl on September 10, 2023 05:19

During internal testing we discovered 2 more bugs -
1. The element adjoint of a struct_extract can itself be an AddElement.
2. Indirect concrete adjoint materialization was missing a copy operation.

This commit fixes these bugs and adds relevant test cases.
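
The two bugs in this commit message can be illustrated with a small C++ sketch. It is a toy model under assumptions, not the compiler's implementation: `Value`, `LazyAdjoint`, `materializeRecursive`, and the helper functions are hypothetical names. The element of an AddElement node is itself a lazy adjoint, so materialization recurses into it, and the concrete case returns a copy so that later accumulation does not modify the shared original.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Toy tangent value: a scalar leaf, or an aggregate of nested values.
struct Value {
    double scalar = 0.0;
    std::vector<Value> fields;
};

Value leaf(double x) { Value v; v.scalar = x; return v; }
Value aggregate(std::vector<Value> fields) {
    Value v;
    v.fields = std::move(fields);
    return v;
}

// Element-wise accumulation, standing in for AdditiveArithmetic's +=.
void accumulate(Value &lhs, const Value &rhs) {
    lhs.scalar += rhs.scalar;
    for (std::size_t i = 0; i < lhs.fields.size() && i < rhs.fields.size(); ++i)
        accumulate(lhs.fields[i], rhs.fields[i]);
}

struct LazyAdjoint {
    enum class Kind { Concrete, AddElement } kind = Kind::Concrete;
    Value concrete;                              // used when Concrete
    std::shared_ptr<const LazyAdjoint> base;     // used when AddElement
    std::size_t fieldIndex = 0;                  // used when AddElement
    std::shared_ptr<const LazyAdjoint> element;  // may itself be AddElement
};
using LazyPtr = std::shared_ptr<const LazyAdjoint>;

LazyPtr concrete(Value v) {
    auto adj = std::make_shared<LazyAdjoint>();
    adj->concrete = std::move(v);
    return adj;
}

LazyPtr addElement(LazyPtr base, std::size_t index, LazyPtr element) {
    auto adj = std::make_shared<LazyAdjoint>();
    adj->kind = LazyAdjoint::Kind::AddElement;
    adj->base = std::move(base);
    adj->fieldIndex = index;
    adj->element = std::move(element);
    return adj;
}

Value materializeRecursive(const LazyAdjoint &adjoint) {
    if (adjoint.kind == LazyAdjoint::Kind::Concrete)
        return adjoint.concrete;  // returned by value: a copy, never an alias
    Value result = materializeRecursive(*adjoint.base);
    // The element may be an AddElement too, so it is materialized recursively.
    Value element = materializeRecursive(*adjoint.element);
    accumulate(result.fields.at(adjoint.fieldIndex), element);
    return result;
}
```

In this model, an adjoint for `x.a.b` is an outer AddElement on field `a` whose element is an inner AddElement on field `b`.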

asl (Contributor) left a comment

Makes sense, thanks!

asl (Contributor) commented Sep 10, 2023

@swift-ci please test

jkshtj (Contributor Author) commented Sep 10, 2023

@asl looks like the "Swift Test Linux Platform" failed due to some network issues. Can we restart them?

BradLarson (Contributor) commented

@swift-ci Please test Linux platform

asl (Contributor) commented Sep 11, 2023

@swift-ci please test linux

jkshtj (Contributor Author) commented Sep 11, 2023

Huh, I see the following error, which caused the Linux tests to fail:

ERROR: Error fetching remote repo 'origin'

Seems like something may be up with the Linux build hosts?

asl (Contributor) commented Sep 11, 2023

@shahmishal Is CI down?

asl (Contributor) commented Sep 12, 2023

@swift-ci please test linux

asl merged commit d971f12 into swiftlang:main on Sep 12, 2023

Successfully merging this pull request may close these issues.

[Autodiff]: Failure to synthesize conformance to AdditiveArithmetic in some cases.