[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable. #27506
@bgogul I think you might have forgotten to handle the switch at: https://github.com/apple/swift/pull/27506/files#diff-a99d62dc86d23321183dc044483cf2caR808 |
I think the only thing we need to upstream in AutoDiff.h for parsing |
This PR introduces the `@transposing` attribute to mark functions as transposing other functions. This PR only contains changes related to parsing the attribute. Type checking and other changes will be added in subsequent patches. This work is related to the `@differentiable` attribute in #27506.
@DougGregor could you review this patch when you get a chance? |
@swift-ci please smoke test |
@swift-ci please clean smoke test |
@swift-ci please smoke test linux |
@DougGregor, hold off on reviewing this one. I will try to split it further so that it is slightly easier to review.
@swift-ci please smoke test |
include/swift/AST/Attr.def
Outdated
DECL_ATTR(differentiable, Differentiable,
  OnAccessor | OnConstructor | OnFunc | OnVar | OnSubscript | LongAttribute |
  AllowMultipleAttributes |
  ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
I believe this affects ABI, so I'd recommend removing the `ABIStableToAdd | ABIStableToRemove` tags, as well as `APIStableToRemove`, which seems inappropriate.
It's not clear to me what these tags mean exactly, but I'll try the following configuration on the `tensorflow` branch to see if things break: `ABIBreakingToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove`
Testing on the `tensorflow` branch in #28148. The PR description has more context.
Replaced the attributes to be similar to the `tensorflow` branch.
looks great with some changes requested, thanks!
include/swift/AST/AutoDiff.h
Outdated
#define SWIFT_AST_AUTODIFF_H

#include "ASTContext.h"
#include "llvm/ADT/SmallBitVector.h"
SmallBitVector.h looks unused.
Also, is ASTContext.h necessary here? Can you forward declare or include smaller headers? ASTContext.h includes the world.
Thanks for catching these. It turned out I only needed `IndexSubset.h` instead of `ASTContext.h`.
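For readers following the include discussion, here is a minimal sketch of the forward-declaration approach the reviewer suggests. `HeavyContext`, `ParsedParameterList`, and `hasContext` are hypothetical names for illustration and are not taken from the patch; the point is only that a header which stores a pointer to a type does not need that type's full definition.

```cpp
#include <cstddef>

// Hypothetical stand-in for a type defined in a large header.
// A forward declaration is enough here, so that header is not included.
class HeavyContext;

// Storing a pointer (or reference) only requires the declaration above.
// The full definition is needed only when the type is used by value,
// inherited from, or has its members accessed.
struct ParsedParameterList {
  HeavyContext *context = nullptr;
  std::size_t count = 0;
};

// Compares the pointer without ever looking inside HeavyContext.
inline bool hasContext(const ParsedParameterList &list) {
  return list.context != nullptr;
}
```

Translation units that need the members of the forward-declared type include its header themselves, which keeps the cost out of every other includer.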
include/swift/AST/AutoDiff.h
Outdated
  return V.Ordered.Index;
}

enum Kind getKind() const {
nit: s/enum Kind/Kind/
Done.
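As a small illustration of the nit above: in C++ the `enum` keyword in a return type is redundant when nothing else hides the enumeration's name. The class below is a hypothetical stand-in, not the type from `AutoDiff.h`.

```cpp
// Hypothetical class for illustration only.
class ParsedParameter {
public:
  enum Kind { Named, Ordered, Self };

  explicit ParsedParameter(Kind K) : TheKind(K) {}

  // `enum Kind getKind() const` would also compile, but the elaborated
  // `enum` keyword adds nothing here, so the plain name is preferred.
  Kind getKind() const { return TheKind; }

private:
  Kind TheKind;
};
```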
@@ -907,6 +907,18 @@ bool Parser::parseMatchingToken(tok K, SourceLoc &TokLoc, Diag<> ErrorDiag,
  return false;
}

bool Parser::parseUnsignedInteger(unsigned &Result, SourceLoc &Loc,
Why do you need this new method? DAK_Alignment handling in ParseDecl.cpp does this:
    StringRef alignmentText = Tok.getText();
    unsigned alignmentValue;
    if (alignmentText.getAsInteger(0, alignmentValue)) {
      diagnose(Loc, diag::alignment_must_be_positive_integer);
      return false;
    }
Won't this work for you at the callsite?
If you're trying to introduce a new generally useful utility, I'd recommend doing that as a separate patch which also migrates the existing callsites to it.
I found code similar to the `DAK_Alignment` snippet in a bunch of places. I agree it is better to introduce this in a separate PR. I will do so shortly.
There are also more uses of `parseUnsignedInteger` in the AD code base that have not been upstreamed yet.
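To make the shape of such a utility concrete, here is a self-contained sketch. It is not the `Parser::parseUnsignedInteger` from the patch (whose full signature is not shown in this thread): it takes a `std::string` instead of reading the current token, emits no diagnostics, and handles base 10 only, whereas `getAsInteger(0, ...)` auto-detects the radix. It returns `true` on failure, an assumption chosen to match the way `getAsInteger` is used in the snippet above.

```cpp
#include <limits>
#include <string>

// Hypothetical helper: parse `text` as a base-10 unsigned integer.
// Returns true on error (empty input, non-digit character, or overflow)
// and leaves `result` untouched in that case.
inline bool parseUnsignedInteger(const std::string &text, unsigned &result) {
  if (text.empty())
    return true;
  const unsigned maxValue = std::numeric_limits<unsigned>::max();
  unsigned value = 0;
  for (char c : text) {
    if (c < '0' || c > '9')
      return true; // not a decimal digit (this also rejects a sign)
    unsigned digit = static_cast<unsigned>(c - '0');
    // value * 10 + digit must not exceed maxValue.
    if (value > (maxValue - digit) / 10)
      return true; // would overflow
    value = value * 10 + digit;
  }
  result = value;
  return false;
}
```

A call site would then reduce to a single `if (parseUnsignedInteger(text, value))` followed by whatever diagnostic is appropriate there.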
lib/Parse/ParseDecl.cpp
Outdated
// Check that token after comma is 'wrt:' or a function specifier label.
if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
                                  Tok.getText() == "jvp" ||
                                  Tok.getText() == "vjp")) {
Would it help to have a function like `enum { wrt, jvp, vjp, invalid } classifyLabel(StringRef str);` to keep all the classification logic in sync and move the magic string parsing into a single place?
“jvp:” and “vjp:” will go away very soon with the almost-finished retroactive derivative registration feature (‘@differentiating’). So I think it is fine to leave them as is!
Good to know, but it is important to keep master consistent and following best practices. If these aren't important, then it would be fine to remove them from the patch. If they need to be in the patch, then please consider implementing them in a nice way :).
I'm not saying that classifyLabel is an appropriate thing, just saying that the rationale doesn't make sense.
Ok, sounds good!
I have added a bunch of helpers. PTAL.
(Although other code in master is using `Tok.getText()` to compare against literal strings.)
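For context, a minimal sketch of the kind of classification helper discussed in this thread. The names `DifferentiableLabel`, `classifyLabel`, and `isValidLabel` are hypothetical, and the helpers actually added to the patch may look different; the sketch only shows how the three label strings can live in one place.

```cpp
#include <string>

// Hypothetical enumeration of the argument labels mentioned in the thread.
enum class DifferentiableLabel { Wrt, Jvp, Vjp, Invalid };

// Single place where the label strings are spelled out.
inline DifferentiableLabel classifyLabel(const std::string &text) {
  if (text == "wrt")
    return DifferentiableLabel::Wrt;
  if (text == "jvp")
    return DifferentiableLabel::Jvp;
  if (text == "vjp")
    return DifferentiableLabel::Vjp;
  return DifferentiableLabel::Invalid;
}

// Convenience predicate for checks like the one in the diff above.
inline bool isValidLabel(const std::string &text) {
  return classifyLabel(text) != DifferentiableLabel::Invalid;
}
```

With something like this, removing `jvp:` and `vjp:` later would mean touching one function rather than every comparison site.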
lib/Parse/ParseDecl.cpp
Outdated
    funcDiag, /*allowOperators=*/true,
    /*allowZeroArgCompoundNames=*/true);
// If no trailing comma or 'where' clause, terminate parsing arguments.
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
`Tok.isNot` takes multiple arguments; please use that instead of `&&` here and anywhere else this comes up.
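The suggestion can be illustrated with a small stand-in token class. This is a sketch assuming a variadic `isNot` like the one the reviewer describes; the `tok` enumeration and `Token` class below are simplified illustrations, not the compiler's own definitions.

```cpp
// Simplified stand-ins for illustration only.
enum class tok { comma, kw_where, identifier, eof };

class Token {
  tok Kind;

public:
  explicit Token(tok K) : Kind(K) {}

  bool is(tok K) const { return Kind == K; }

  // True if the token matches any of the listed kinds.
  bool isAny(tok K) const { return is(K); }
  template <typename... Rest>
  bool isAny(tok K1, tok K2, Rest... Ks) const {
    return is(K1) || isAny(K2, Ks...);
  }

  // True if the token matches none of the listed kinds.
  template <typename... Kinds>
  bool isNot(tok K1, Kinds... Ks) const {
    return !isAny(K1, Ks...);
  }
};
```

With this shape, `Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where)` becomes `Tok.isNot(tok::comma, tok::kw_where)`.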
@@ -0,0 +1,180 @@
// RUN: %target-swift-frontend -parse -verify %s

/// Good
nice, thank you for the testcases!
Thanks for the review, Chris!
LGTM. |
@swift-ci please smoke test |
@swift-ci please smoke test macOS |
@swift-ci please smoke test |
This broke the build. When making larger changes like this, please run a full test.
Sorry about that! We'll be sure to use
I'm taking a look into the issue now. @gottesmm: would you prefer that we revert this PR in the meantime? EDIT: |
Update gyb-generated files for `@differentiable` attribute. Friend PR: swiftlang/swift#27506
@dan-zheng @gottesmm PR size has nothing to do with it. Any functional change that could conceivably be sensitive to build configuration needs to run normal PR testing (swift-ci test). I always run normal testing on a functional change, but a lot of people aren't doing this. For a major functional change full PR testing would include benchmarks, SCK, and maybe even compiler performance tests. |
@dan-zheng I noticed today that this PR includes serialization for |
Thanks for noticing!
TF-836 tracks Small aside:
†: Serializing this is necessary to print deserialized It's worthwhile to discuss which attribute components should be serialized! I started a forum question for discussion. Precedent: |
This PR introduces the `@differentiable` attribute to mark functions as differentiable. This PR only contains changes related to parsing the attribute. Type checking and other changes will be added in subsequent patches.
See https://github.com/apple/swift/pull/27506/files#diff-f3216f4188fd5ed34e1007e5a9c2490f for examples and tests for the new attribute.