Skip to content

Commit 7731c0e

Browse files
authored
Add libSyntax support for #chainableGradient. (#18255)
* Add #chainableGradient differential operator for seedable gradient * Handle ExprKind::ChainableGradient. * Add libSyntax support for #chainableGradient. * Rename GradientExpr to ReverseAutoDiffExpr for consistency with the class hierarchy.
1 parent 482690a commit 7731c0e

File tree

5 files changed

+22
-16
lines changed

5 files changed

+22
-16
lines changed

lib/Parse/ParseExpr.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3665,7 +3665,8 @@ ParserResult<Expr> Parser::parseExprTypeOf() {
36653665
/// '.' [0-9]+
36663666
///
36673667
ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
3668-
SyntaxParsingContext GradientContext(SyntaxContext, SyntaxKind::GradientExpr);
3668+
SyntaxParsingContext RADEContext(SyntaxContext,
3669+
SyntaxKind::ReverseAutoDiffExpr);
36693670

36703671
assert(Tok.isAny(tok::pound_gradient, tok::pound_valueAndGradient,
36713672
tok::pound_chainableGradient));
@@ -3727,7 +3728,7 @@ ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
37273728
// Function that parses one parameter.
37283729
auto parseParam = [&]() -> bool {
37293730
SyntaxParsingContext DiffParamContext(
3730-
SyntaxContext, SyntaxKind::GradientExprDiffParam);
3731+
SyntaxContext, SyntaxKind::ReverseAutoDiffExprParam);
37313732
SourceLoc paramLoc;
37323733
switch (Tok.getKind()) {
37333734
case tok::period_prefix: {
@@ -3757,7 +3758,8 @@ ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
37573758
while (Tok.isNot(tok::r_paren))
37583759
if (parseParam())
37593760
return errorAndSkipToEnd();
3760-
SyntaxContext->collectNodesInPlace(SyntaxKind::GradientExprParamList);
3761+
SyntaxContext->collectNodesInPlace(
3762+
SyntaxKind::ReverseAutoDiffExprParamList);
37613763
}
37623764
// Parse the closing ')'.
37633765
if (parseToken(tok::r_paren, rParenLoc, diag::expr_expected_rparen, exprName))

test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,11 +561,12 @@ func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Num
561561
func bar<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
562562

563563
@differentiable(<DifferentiableAttributeArguments>reverse, <DifferentiableAttributeDiffParams>wrt: (<DifferentiableAttributeDiffParam>self, </DifferentiableAttributeDiffParam><DifferentiableAttributeDiffParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </DifferentiableAttributeDiffParam><DifferentiableAttributeDiffParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></DifferentiableAttributeDiffParam>), </DifferentiableAttributeDiffParams><DifferentiableAttributeFuncSpecifier>primal: bar, </DifferentiableAttributeFuncSpecifier><DifferentiableAttributeFuncSpecifier>adjoint: foo<DeclNameArguments>(<DeclNameArgument>_:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>) </DeclNameArguments></DifferentiableAttributeFuncSpecifier><GenericWhereClause>where <ConformanceRequirement><SimpleTypeIdentifier>T </SimpleTypeIdentifier>: <SimpleTypeIdentifier>FloatingPoint</SimpleTypeIdentifier></ConformanceRequirement></GenericWhereClause></DifferentiableAttributeArguments>)</Attribute>
564-
func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Numeric</SimpleTypeIdentifier></GenericParameter>></GenericParameterClause><FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>T</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>T</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>T </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><GradientExpr>
564+
func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Numeric</SimpleTypeIdentifier></GenericParameter>></GenericParameterClause><FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>T</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>T</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>T </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><ReverseAutoDiffExpr>
565565

566-
#gradient(<IdentifierExpr>foo</IdentifierExpr>)</GradientExpr><GradientExpr>
567-
#gradient(<IdentifierExpr>foo</IdentifierExpr>, wrt: <GradientExprDiffParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </GradientExprDiffParam><GradientExprDiffParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></GradientExprDiffParam>)</GradientExpr><GradientExpr>
568-
#valueAndGradient(<IdentifierExpr>foo</IdentifierExpr>, wrt: <GradientExprDiffParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </GradientExprDiffParam><GradientExprDiffParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></GradientExprDiffParam>)</GradientExpr><AdjointExpr>
566+
#gradient(<IdentifierExpr>foo</IdentifierExpr>)</ReverseAutoDiffExpr><ReverseAutoDiffExpr>
567+
#gradient(<IdentifierExpr>foo</IdentifierExpr>, wrt: <ReverseAutoDiffExprParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </ReverseAutoDiffExprParam><ReverseAutoDiffExprParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></ReverseAutoDiffExprParam>)</ReverseAutoDiffExpr><ReverseAutoDiffExpr>
568+
#chainableGradient(<IdentifierExpr>foo</IdentifierExpr>, wrt: <ReverseAutoDiffExprParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </ReverseAutoDiffExprParam><ReverseAutoDiffExprParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></ReverseAutoDiffExprParam>)</ReverseAutoDiffExpr><ReverseAutoDiffExpr>
569+
#valueAndGradient(<IdentifierExpr>foo</IdentifierExpr>, wrt: <ReverseAutoDiffExprParam><DifferentiationIndexParam>.0</DifferentiationIndexParam>, </ReverseAutoDiffExprParam><ReverseAutoDiffExprParam><DifferentiationIndexParam>.1</DifferentiationIndexParam></ReverseAutoDiffExprParam>)</ReverseAutoDiffExpr><AdjointExpr>
569570

570571
#adjoint(+)</AdjointExpr><AdjointExpr>
571572
#adjoint(foo<DeclNameArguments>(<DeclNameArgument>_:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments>)</AdjointExpr><AdjointExpr>

test/Syntax/round_trip_parse_gen.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ func bar<T : Numeric>(_ x: T, _: T) -> T { return 1 }
565565
566566
#gradient(foo)
567567
#gradient(foo, wrt: .0, .1)
568+
#chainableGradient(foo, wrt: .0, .1)
568569
#valueAndGradient(foo, wrt: .0, .1)
569570
570571
#adjoint(+)

utils/gyb_syntax_support/ExprNodes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -571,14 +571,14 @@
571571
]),
572572

573573
# SWIFT_ENABLE_TENSORFLOW
574-
# Expression generalizing #gradient and #valueAndGradient.
575-
# e.g. #gradient(foo(_:_:), wrt: .0, .1)
576-
Node('GradientExpr', kind='Expr',
574+
# e.g. "#gradient(foo(_:_:), wrt: .0, .1)"
575+
Node('ReverseAutoDiffExpr', kind='Expr',
577576
traits=['Parenthesized'],
578577
children=[
579578
Child('Identifier', kind='Token',
580579
token_choices=[
581580
'PoundGradientToken',
581+
'PoundChainableGradientToken',
582582
'PoundValueAndGradientToken',
583583
]),
584584
Child('LeftParen', kind='LeftParenToken'),
@@ -587,19 +587,19 @@
587587
Child('WrtLabel', kind='IdentifierToken', is_optional=True,
588588
text_choices=['wrt']),
589589
Child('Colon', kind='ColonToken', is_optional=True),
590-
Child('DiffParams', kind='GradientExprParamList',
590+
Child('DiffParams', kind='ReverseAutoDiffExprParamList',
591591
is_optional=True),
592592
Child('RightParen', kind='RightParenToken'),
593593
]),
594594

595-
# gradient-expr-diff-param-list ->
595+
# reverse-autodiff-expr-param-list ->
596596
# gradient-expr-diff-param gradient-expr-diff-param-list?
597-
Node('GradientExprParamList', kind='SyntaxCollection',
598-
element='GradientExprDiffParam'),
597+
Node('ReverseAutoDiffExprParamList', kind='SyntaxCollection',
598+
element='ReverseAutoDiffExprParam'),
599599

600-
# gradient-expr-diff-param ->
600+
# reverse-autodiff-expr-param ->
601601
# differentiation-index-param ','?
602-
Node('GradientExprDiffParam', kind='Syntax',
602+
Node('ReverseAutoDiffExprParam', kind='Syntax',
603603
description='''
604604
A differentiation parameter: a period followed by an unsigned integer \
605605
(e.g. `.0`).

utils/gyb_syntax_support/Token.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __init__(self, name, text):
125125
is_keyword=True),
126126
Token('PoundGradient', 'pound_gradient', text='#gradient',
127127
is_keyword=True),
128+
Token('PoundChainableGradient', 'pound_chainableGradient',
129+
text='#chainableGradient', is_keyword=True),
128130
Token('PoundValueAndGradient', 'pound_valueAndGradient',
129131
text='#valueAndGradient', is_keyword=True),
130132
Token('PoundAdjoint', 'pound_adjoint', text='#adjoint',

0 commit comments

Comments
 (0)