Skip to content

Commit 37b507b

Browse files
authored
[AutoDiff] Fix syntax support for @derivative and @transpose. (#28761)
Add `QualifiedDeclName` syntax node representing qualified declaration names with an optional base type. Add syntax support for `@transpose` attribute with qualified declaration. Fix trailing comma parsing for `DerivativeRegistrationAttributeArguments` syntax. Resolves TF-1009. Enables TF-1058: `@derivative` attribute with qualified declaration.
1 parent 39f3336 commit 37b507b

File tree

5 files changed

+94
-48
lines changed

5 files changed

+94
-48
lines changed

lib/Parse/ParseDecl.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ static bool errorAndSkipUntilConsumeRightParen(Parser &P, StringRef attrName,
877877
bool Parser::parseDifferentiationParametersClause(
878878
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
879879
SyntaxParsingContext DiffParamsClauseContext(
880-
SyntaxContext, SyntaxKind::DifferentiationParamsClause);
880+
SyntaxContext, SyntaxKind::DifferentiationParamsClause);
881881
consumeToken(tok::identifier);
882882
if (!consumeIf(tok::colon)) {
883883
diagnose(Tok, diag::expected_colon_after_label, "wrt");
@@ -1207,17 +1207,19 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
12071207
}
12081208
{
12091209
// Parse the name of the function.
1210-
SyntaxParsingContext FuncDeclNameContext(SyntaxContext,
1211-
SyntaxKind::FunctionDeclName);
1210+
// TODO(TF-1058): Make `@derivative` attribute support qualified
1211+
// original declarations.
1212+
SyntaxParsingContext DeclNameContext(SyntaxContext,
1213+
SyntaxKind::QualifiedDeclName);
12121214
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to
12131215
// enable, e.g. `@derivative(of: init)` and `@derivative(of: subscript)`.
12141216
original.Name = parseUnqualifiedDeclName(
12151217
/*afterDot*/ true, original.Loc,
12161218
diag::attr_derivative_expected_original_name, /*allowOperators*/ true,
12171219
/*allowZeroArgCompoundNames*/ true, /*allowDeinitAndSubscript*/ true);
1218-
if (consumeIfTrailingComma())
1219-
return makeParserError();
12201220
}
1221+
if (consumeIfTrailingComma())
1222+
return makeParserError();
12211223
// Parse the optional 'wrt' differentiation parameters clause.
12221224
if (isIdentifier(Tok, "wrt") &&
12231225
parseDifferentiationParametersClause(params, AttrName))
@@ -1269,18 +1271,20 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
12691271
SyntaxKind::DeprecatedDerivativeRegistrationAttributeArguments);
12701272
{
12711273
// Parse the name of the function.
1272-
SyntaxParsingContext FuncDeclNameContext(
1273-
SyntaxContext, SyntaxKind::FunctionDeclName);
1274+
// TODO(TF-1058): Make `@differentiating` attribute support qualified
1275+
// original declarations.
1276+
SyntaxParsingContext DeclNameContext(SyntaxContext,
1277+
SyntaxKind::QualifiedDeclName);
12741278
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to
12751279
// enable, e.g. `@differentiating(init)` and
12761280
// `@differentiating(subscript)`.
12771281
original.Name = parseUnqualifiedDeclName(
12781282
/*afterDot*/ true, original.Loc,
12791283
diag::attr_derivative_expected_original_name, /*allowOperators*/ true,
12801284
/*allowZeroArgCompoundNames*/ true, /*allowDeinitAndSubscript*/ true);
1281-
if (consumeIfTrailingComma())
1282-
return makeParserError();
12831285
}
1286+
if (consumeIfTrailingComma())
1287+
return makeParserError();
12841288
// Parse the optional 'wrt' differentiation parameters clause.
12851289
if (isIdentifier(Tok, "wrt") &&
12861290
parseDifferentiationParametersClause(params, AttrName))
@@ -1335,6 +1339,8 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
13351339
/// Returns true on error (if function decl name could not be parsed).
13361340
bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
13371341
TypeRepr *&baseType, DeclNameWithLoc &original) {
1342+
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
1343+
SyntaxKind::QualifiedDeclName);
13381344
if (parseBaseTypeForQualifiedDeclName(P, baseType))
13391345
return true;
13401346

@@ -1399,18 +1405,13 @@ ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
13991405
}
14001406
{
14011407
// Parse the optionally qualified function name.
1402-
// TODO(TF-1009): Fix syntax support for dot-separated qualified names.
1403-
// Currently, `SyntaxKind::FunctionDeclName` only supports unqualified
1404-
// names.
1405-
SyntaxParsingContext FuncDeclNameContext(SyntaxContext,
1406-
SyntaxKind::FunctionDeclName);
14071408
if (parseQualifiedDeclName(*this,
14081409
diag::attr_transpose_expected_original_name,
14091410
baseType, original))
14101411
return makeParserError();
1411-
if (consumeIfTrailingComma())
1412-
return makeParserError();
14131412
}
1413+
if (consumeIfTrailingComma())
1414+
return makeParserError();
14141415
// Parse the optional 'wrt' transposed parameters clause.
14151416
if (Tok.is(tok::identifier) && Tok.getText() == "wrt" &&
14161417
parseTransposedParametersClause(params, AttrName))

test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -601,32 +601,35 @@ func bar<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleType
601601
@differentiable(<DifferentiableAttributeArguments><DifferentiationParamsClause>wrt: <DifferentiationParams>(<DifferentiationParam>self, </DifferentiationParam><DifferentiationParam>x, </DifferentiationParam><DifferentiationParam>y</DifferentiationParam>)</DifferentiationParams></DifferentiationParamsClause>, <DifferentiableAttributeFuncSpecifier>jvp: <FunctionDeclName>bar</FunctionDeclName>, </DifferentiableAttributeFuncSpecifier><DifferentiableAttributeFuncSpecifier>vjp: <FunctionDeclName>foo<DeclNameArguments>(<DeclNameArgument>_:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>) </DeclNameArguments></FunctionDeclName></DifferentiableAttributeFuncSpecifier><GenericWhereClause>where <GenericRequirement><ConformanceRequirement><SimpleTypeIdentifier>T </SimpleTypeIdentifier>: <SimpleTypeIdentifier>FloatingPoint</SimpleTypeIdentifier></ConformanceRequirement></GenericRequirement></GenericWhereClause></DifferentiableAttributeArguments>)</Attribute>
602602
func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Numeric</SimpleTypeIdentifier></GenericParameter>></GenericParameterClause><FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>T</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>T</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>T </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
603603

604-
@derivative(<DerivativeRegistrationAttributeArguments>of: <FunctionDeclName>-</FunctionDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
604+
@derivative(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>-</QualifiedDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
605605
func negateDerivative<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>)</ParameterClause><ReturnClause>
606606
-> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
607607
return <TupleExpr>(<TupleExprElement><PrefixOperatorExpr>-<IdentifierExpr>x</IdentifierExpr></PrefixOperatorExpr>, </TupleExprElement><TupleExprElement><ClosureExpr>{ <ClosureSignature><ClosureParam>v </ClosureParam>in </ClosureSignature><PrefixOperatorExpr>-<IdentifierExpr>v </IdentifierExpr></PrefixOperatorExpr>}</ClosureExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
608608
}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
609609

610-
@derivative(<DerivativeRegistrationAttributeArguments>of: <FunctionDeclName>baz<DeclNameArguments>(<DeclNameArgument>label:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments></FunctionDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
610+
@derivative(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>baz<DeclNameArguments>(<DeclNameArgument>label:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments></QualifiedDeclName>, <DifferentiationParamsClause>wrt: <DifferentiationParams>(<DifferentiationParam>x</DifferentiationParam>)</DifferentiationParams></DifferentiationParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
611611
func bazDerivative<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>)</ParameterClause><ReturnClause>
612-
-> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>)</TupleType></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
612+
-> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
613613
return <TupleExpr>(<TupleExprElement><IdentifierExpr>x</IdentifierExpr>, </TupleExprElement><TupleExprElement><ClosureExpr>{ <ClosureSignature><ClosureParam>v </ClosureParam>in </ClosureSignature><IdentifierExpr>v </IdentifierExpr>}</ClosureExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
614614
}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
615615

616-
@transpose(<DerivativeRegistrationAttributeArguments>of: <FunctionDeclName>+</FunctionDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
616+
@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>+</QualifiedDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
617617
func addTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
618618
return <TupleExpr>(<TupleExprElement><IdentifierExpr>v</IdentifierExpr>, </TupleExprElement><TupleExprElement><IdentifierExpr>v</IdentifierExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
619619
}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
620620

621-
@differentiating(<DeprecatedDerivativeRegistrationAttributeArguments><FunctionDeclName>baz<DeclNameArguments>(<DeclNameArgument>label:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments></FunctionDeclName></DeprecatedDerivativeRegistrationAttributeArguments>)</Attribute>
621+
@differentiating(<DeprecatedDerivativeRegistrationAttributeArguments><QualifiedDeclName>baz<DeclNameArguments>(<DeclNameArgument>label:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments></QualifiedDeclName>, <DifferentiationParamsClause>wrt: <DifferentiationParams>(<DifferentiationParam>x</DifferentiationParam>)</DifferentiationParams></DifferentiationParamsClause></DeprecatedDerivativeRegistrationAttributeArguments>)</Attribute>
622622
func bazDerivative<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>)</ParameterClause><ReturnClause>
623-
-> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>)</TupleType></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
623+
-> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
624624
return <TupleExpr>(<TupleExprElement><IdentifierExpr>x</IdentifierExpr>, </TupleExprElement><TupleExprElement><ClosureExpr>{ <ClosureSignature><ClosureParam>v </ClosureParam>in </ClosureSignature><IdentifierExpr>v </IdentifierExpr>}</ClosureExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
625-
}</CodeBlock></FunctionDecl>
625+
}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
626626

627-
// TODO(TF-1009): Add syntax support for dot-separated qualified names in
628-
// `@transpose(of:)` attributes.
629-
@transpose(of: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>.-)
627+
@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>-</QualifiedDeclName>, <DifferentiationParamsClause>wrt: <DifferentiationParams>(<DifferentiationParam>0, </DifferentiationParam><DifferentiationParam>1</DifferentiationParam>)</DifferentiationParams></DifferentiationParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
628+
func subtractTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
629+
return <TupleExpr>(<TupleExprElement><IdentifierExpr>v</IdentifierExpr>, </TupleExprElement><TupleExprElement><PrefixOperatorExpr>-<IdentifierExpr>v</IdentifierExpr></PrefixOperatorExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
630+
}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
631+
632+
@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>.-</QualifiedDeclName>, <DifferentiationParamsClause>wrt: <DifferentiationParam>0</DifferentiationParam></DifferentiationParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
630633
func negateTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
631634
return <PrefixOperatorExpr>-<IdentifierExpr>v</IdentifierExpr></PrefixOperatorExpr></ReturnStmt>
632-
}</CodeBlock>
635+
}</CodeBlock></FunctionDecl>

test/Syntax/round_trip_parse_gen.swift

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,9 @@ func negateDerivative(_ x: Float)
607607
return (-x, { v in -v })
608608
}
609609

610-
@derivative(of: baz(label:_:))
610+
@derivative(of: baz(label:_:), wrt: (x))
611611
func bazDerivative(_ x: Float, y: Float)
612-
-> (value: Float, pullback: (Float) -> (Float, Float)) {
612+
-> (value: Float, pullback: (Float) -> Float) {
613613
return (x, { v in v })
614614
}
615615

@@ -618,15 +618,18 @@ func addTranspose(_ v: Float) -> (Float, Float) {
618618
return (v, v)
619619
}
620620

621-
@differentiating(baz(label:_:))
621+
@differentiating(baz(label:_:), wrt: (x))
622622
func bazDerivative(_ x: Float, y: Float)
623-
-> (value: Float, pullback: (Float) -> (Float, Float)) {
623+
-> (value: Float, pullback: (Float) -> Float) {
624624
return (x, { v in v })
625625
}
626626

627-
// TODO(TF-1009): Add syntax support for dot-separated qualified names in
628-
// `@transpose(of:)` attributes.
629-
@transpose(of: Float.-)
627+
@transpose(of: -, wrt: (0, 1))
628+
func subtractTranspose(_ v: Float) -> (Float, Float) {
629+
return (v, -v)
630+
}
631+
632+
@transpose(of: Float.-, wrt: 0)
630633
func negateTranspose(_ v: Float) -> Float {
631634
return -v
632635
}

0 commit comments

Comments
 (0)