Commit 0c66055
[AutoDiff] Add #chainableGradient differential operator for seedable gradient (#18250)
* Add #chainableGradient differential operator for seedable gradient
* Handle ExprKind::ChainableGradient.
1 parent f2baa0c commit 0c66055
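For orientation, here is a hedged sketch of the surface syntax this commit enables, next to the existing #gradient operator. The function f, the result types, and the exact position of the seed parameter are illustrative assumptions, not taken from this commit:

func f(_ x: Float, _ y: Float) -> Float {
  return x * y
}

// Existing operator: a function computing the gradient of f.
let df = #gradient(f)                 // assumed: (Float, Float) -> (Float, Float)

// New operator: like #gradient, but the returned function takes one extra
// result-typed argument, the seed (the backpropagated adjoint).
let dfChained = #chainableGradient(f) // assumed: (Float, Float, Float) -> (Float, Float)

// With a seed of 1, this should recover the plain gradient (assumed behavior).
let (dx, dy) = dfChained(3, 4, 1)     // expected (4, 3) for f(x, y) = x * y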

11 files changed: +93 −18 lines

include/swift/AST/Expr.h

Lines changed: 32 additions & 1 deletion

@@ -3818,7 +3818,8 @@ class DynamicTypeExpr : public Expr {
 };
 
 /// SWIFT_ENABLE_TENSORFLOW
-/// Base class for #gradient and #valueAndGradient expressions.
+/// Base class for differential operators, such as `#gradient`,
+/// `#chainableGradient`, and `#valueAndGradient`.
 class ReverseAutoDiffExpr : public Expr {
 public:
   Expr *getOriginalExpr() const {

@@ -3894,6 +3895,36 @@ class GradientExpr : public ReverseAutoDiffExpr {
       : ReverseAutoDiffExpr(ExprKind::Gradient, loc, lParenLoc, originalExpr,
                             params, rParenLoc) {}
 };
+
+/// Chainable gradient expression - An expression that produces the
+/// automatically differentiated function that computes the gradient (or
+/// vector-Jacobian products) with respect to specified parameters, taking an
+/// extra result-typed argument representing the seed, i.e. the backpropagated
+/// adjoint.
+/// Examples:
+///   #chainableGradient(baz)
+///   #chainableGradient(bar, wrt: .0, .1)
+///   #chainableGradient(foo(_:_:), wrt: .0)
+///
+class ChainableGradientExpr : public ReverseAutoDiffExpr {
+public:
+  static ChainableGradientExpr *create(ASTContext &ctx, SourceLoc loc,
+                                       SourceLoc lParenLoc, Expr *originalExpr,
+                                       ArrayRef<AutoDiffIndexParameter> parameters,
+                                       SourceLoc rParenLoc);
+
+  static bool classof(const Expr *E) {
+    return E->getKind() == ExprKind::ChainableGradient;
+  }
+
+private:
+  explicit ChainableGradientExpr(SourceLoc loc, SourceLoc lParenLoc,
+                                 Expr *originalExpr,
+                                 ArrayRef<AutoDiffIndexParameter> params,
+                                 SourceLoc rParenLoc)
+      : ReverseAutoDiffExpr(ExprKind::ChainableGradient, loc, lParenLoc,
+                            originalExpr, params, rParenLoc) {}
+};
 
 /// ValueAndGradient expression - An expression that produces an automatically
 /// differentiated function that returns the result of the original function and
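The doc comment above is the heart of the change: the extra seed argument is what makes gradients chainable, because the adjoint flowing back from a later computation can be fed in directly. A hedged sketch of that composition, assuming scalar functions and a trailing seed parameter (this commit only adds parsing, AST, and SILGen plumbing, so the sketch is illustrative):

func f(_ x: Float) -> Float { return x * x }
func g(_ y: Float) -> Float { return 3 * y }

let dg = #gradient(g)                 // assumed: (Float) -> Float
let dfChained = #chainableGradient(f) // assumed: (Float, Float) -> Float

// Chain rule for h = g ∘ f: dh/dx = df/dx ⋅ dg/dy, evaluated at y = f(x).
// g's adjoint is passed to f's gradient as the seed instead of being
// multiplied in by hand.
func dh(_ x: Float) -> Float {
  return dfChained(x, dg(f(x)))
}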

include/swift/AST/ExprNodes.def

Lines changed: 1 addition & 0 deletions

@@ -190,6 +190,7 @@ UNCHECKED_EXPR(KeyPathDot, Expr)
 // SWIFT_ENABLE_TENSORFLOW
 ABSTRACT_EXPR(ReverseAutoDiff, Expr)
   EXPR(Gradient, ReverseAutoDiffExpr)
+  EXPR(ChainableGradient, ReverseAutoDiffExpr)
   EXPR(ValueAndGradient, ReverseAutoDiffExpr)
 EXPR(Adjoint, Expr)
 EXPR(PoundAssert, Expr)

include/swift/Syntax/TokenKinds.def

Lines changed: 1 addition & 0 deletions

@@ -286,6 +286,7 @@ POUND_OBJECT_LITERAL(tfop, "tfop", ExpressibleByTensorFlowOp)
 
 // SWIFT_ENABLE_TENSORFLOW
 POUND_KEYWORD(gradient)
+POUND_KEYWORD(chainableGradient)
 POUND_KEYWORD(valueAndGradient)
 POUND_KEYWORD(adjoint)
 POUND_KEYWORD(assert)

lib/AST/ASTDumper.cpp

Lines changed: 5 additions & 0 deletions

@@ -1857,6 +1857,11 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
     printCommon(E, "gradient_expr");
     printReverseAutoDiffExpr(E);
   }
+
+  void visitChainableGradientExpr(ChainableGradientExpr *E) {
+    printCommon(E, "chainable_gradient_expr");
+    printReverseAutoDiffExpr(E);
+  }
 
   void visitValueAndGradientExpr(ValueAndGradientExpr *E) {
     printCommon(E, "value_and_gradient_expr");

lib/AST/Expr.cpp

Lines changed: 17 additions & 2 deletions

@@ -358,6 +358,7 @@ ConcreteDeclRef Expr::getReferencedDecl() const {
   NO_REFERENCE(KeyPathDot);
   // SWIFT_ENABLE_TENSORFLOW
   NO_REFERENCE(Gradient);
+  NO_REFERENCE(ChainableGradient);
   NO_REFERENCE(ValueAndGradient);
   NO_REFERENCE(Adjoint);
   NO_REFERENCE(PoundAssert);

@@ -532,6 +533,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const {
   case ExprKind::KeyPath:
   // SWIFT_ENABLE_TENSORFLOW
   case ExprKind::Gradient:
+  case ExprKind::ChainableGradient:
   case ExprKind::ValueAndGradient:
   case ExprKind::Adjoint:
     return true;

@@ -1142,21 +1144,34 @@ GradientExpr *GradientExpr::create(ASTContext &ctx, SourceLoc loc,
                                    SourceLoc rParenLoc) {
   unsigned numParams = parameters.size();
   unsigned size =
-    sizeof(GradientExpr) + numParams * sizeof(AutoDiffIndexParameter);
+      sizeof(GradientExpr) + numParams * sizeof(AutoDiffIndexParameter);
   void *memory = ctx.Allocate(size, alignof(GradientExpr));
   return new (memory) GradientExpr(loc, lParenLoc, originalExpr, parameters,
                                    rParenLoc);
 }
 
 
+ChainableGradientExpr *
+ChainableGradientExpr::create(ASTContext &ctx, SourceLoc loc,
+                              SourceLoc lParenLoc, Expr *originalExpr,
+                              ArrayRef<AutoDiffIndexParameter> parameters,
+                              SourceLoc rParenLoc) {
+  unsigned numParams = parameters.size();
+  unsigned size = sizeof(ChainableGradientExpr)
+      + numParams * sizeof(AutoDiffIndexParameter);
+  void *memory = ctx.Allocate(size, alignof(ChainableGradientExpr));
+  return new (memory) ChainableGradientExpr(loc, lParenLoc, originalExpr,
+                                            parameters, rParenLoc);
+}
+
 ValueAndGradientExpr *
 ValueAndGradientExpr::create(ASTContext &ctx, SourceLoc loc,
                              SourceLoc lParenLoc, Expr *originalExpr,
                              ArrayRef<AutoDiffIndexParameter> parameters,
                              SourceLoc rParenLoc) {
   unsigned numParams = parameters.size();
   unsigned size =
-    sizeof(ValueAndGradientExpr) + numParams * sizeof(AutoDiffIndexParameter);
+      sizeof(ValueAndGradientExpr) + numParams * sizeof(AutoDiffIndexParameter);
   void *memory = ctx.Allocate(size, alignof(ValueAndGradientExpr));
   return new (memory) ValueAndGradientExpr(loc, lParenLoc, originalExpr,
                                            parameters, rParenLoc);
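ChainableGradientExpr::create follows the same tail-allocation pattern as the existing GradientExpr and ValueAndGradientExpr factories: the size passed to ctx.Allocate reserves room for the AutoDiffIndexParameter list immediately after the node itself, so the wrt: parameters live inline with the expression rather than in a separately allocated buffer.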

lib/Parse/ParseExpr.cpp

Lines changed: 14 additions & 1 deletion

@@ -1764,6 +1764,10 @@ ParserResult<Expr> Parser::parseExprPrimary(Diag<> ID, bool isExprBasic) {
   case tok::pound_gradient:
     return parseExprGradientBody(ExprKind::Gradient);
     break;
+
+  case tok::pound_chainableGradient:
+    return parseExprGradientBody(ExprKind::ChainableGradient);
+    break;
 
   case tok::pound_valueAndGradient:
     return parseExprGradientBody(ExprKind::ValueAndGradient);

@@ -3663,7 +3667,8 @@ ParserResult<Expr> Parser::parseExprTypeOf() {
 ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
   SyntaxParsingContext GradientContext(SyntaxContext, SyntaxKind::GradientExpr);
 
-  assert(Tok.is(tok::pound_gradient) || Tok.is(tok::pound_valueAndGradient));
+  assert(Tok.isAny(tok::pound_gradient, tok::pound_valueAndGradient,
+                   tok::pound_chainableGradient));
   auto poundGradLoc = consumeToken();
   SourceLoc lParenLoc;
   SourceLoc rParenLoc;

@@ -3673,6 +3678,9 @@ ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
   case ExprKind::Gradient:
     exprName = "#gradient";
     break;
+  case ExprKind::ChainableGradient:
+    exprName = "#chainableGradient";
+    break;
   case ExprKind::ValueAndGradient:
     exprName = "#valueAndGradient";
     break;

@@ -3764,6 +3772,11 @@ ParserResult<Expr> Parser::parseExprGradientBody(ExprKind kind) {
                                   originalFnParseResult.get(), params,
                                   rParenLoc);
     break;
+  case ExprKind::ChainableGradient:
+    result = ChainableGradientExpr::create(Context, poundGradLoc, lParenLoc,
+                                           originalFnParseResult.get(), params,
+                                           rParenLoc);
+    break;
   case ExprKind::ValueAndGradient:
     result = ValueAndGradientExpr::create(Context, poundGradLoc, lParenLoc,
                                           originalFnParseResult.get(), params,

lib/SILGen/SILGenExpr.cpp

Lines changed: 6 additions & 0 deletions

@@ -480,6 +480,7 @@ namespace {
                               SGFContext C);
     // SWIFT_ENABLE_TENSORFLOW
     RValue visitGradientExpr(GradientExpr *E, SGFContext C);
+    RValue visitChainableGradientExpr(ChainableGradientExpr *E, SGFContext C);
    RValue visitValueAndGradientExpr(ValueAndGradientExpr *E, SGFContext C);
     RValue visitAdjointExpr(AdjointExpr *E, SGFContext C);
     RValue visitObjectLiteralExpr(ObjectLiteralExpr *E, SGFContext C);

@@ -2824,6 +2825,11 @@ visitGradientExpr(GradientExpr *E, SGFContext C) {
   return emitGradientInst(*this, C, E);
 }
 
+RValue RValueEmitter::
+visitChainableGradientExpr(ChainableGradientExpr *E, SGFContext C) {
+  return emitGradientInst(*this, C, E, SILGradientFlags::Seedable);
+}
+
 RValue RValueEmitter::
 visitValueAndGradientExpr(ValueAndGradientExpr *E, SGFContext C) {
   return emitGradientInst(*this, C, E, SILGradientFlags::PreservingResult);
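SILGen gives #chainableGradient no lowering of its own: it funnels into the same emitGradientInst path as #gradient and #valueAndGradient, distinguished only by SILGradientFlags::Seedable, which presumably marks the emitted gradient function as taking the extra seed argument.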

lib/Sema/CSApply.cpp

Lines changed: 4 additions & 0 deletions

@@ -2439,6 +2439,10 @@ namespace {
     Expr *visitGradientExpr(GradientExpr *expr) {
       return handleReverseAutoDiffExpr(expr, /*preservingOriginalResult=*/false);
     }
+
+    Expr *visitChainableGradientExpr(ChainableGradientExpr *expr) {
+      llvm_unreachable("Unhandled");
+    }
 
     Expr *visitValueAndGradientExpr(ValueAndGradientExpr *expr) {
       return handleReverseAutoDiffExpr(expr, /*preservingOriginalResult=*/true);

lib/Sema/CSDiag.cpp

Lines changed: 7 additions & 14 deletions

@@ -1091,8 +1091,7 @@ class FailureDiagnosis :public ASTVisitor<FailureDiagnosis, /*exprresult*/bool>{
   bool diagnoseSubscriptErrors(SubscriptExpr *SE, bool performingSet);
 
   // SWIFT_ENABLE_TENSORFLOW
-  bool diagnoseReverseAutoDiffExpr(ReverseAutoDiffExpr *GE,
-                                   bool preservingPrimalResult);
+  bool diagnoseReverseAutoDiffExpr(ReverseAutoDiffExpr *GE);
 
   /// Diagnose the usage of 'subscript' instead of the operator when calling
   /// a subscript and offer a fixit if the inputs are compatible.

@@ -1123,8 +1122,7 @@ class FailureDiagnosis :public ASTVisitor<FailureDiagnosis, /*exprresult*/bool>{
   bool visitClosureExpr(ClosureExpr *CE);
   bool visitKeyPathExpr(KeyPathExpr *KPE);
   // SWIFT_ENABLE_TENSORFLOW
-  bool visitGradientExpr(GradientExpr *GE);
-  bool visitValueAndGradientExpr(ValueAndGradientExpr *GE);
+  bool visitReverseAutoDiffExpr(ReverseAutoDiffExpr *RADE);
   bool visitPoundAssertExpr(PoundAssertExpr *PAE);
 };
 } // end anonymous namespace

@@ -7471,11 +7469,10 @@ bool FailureDiagnosis::visitKeyPathExpr(KeyPathExpr *KPE) {
 
 // SWIFT_ENABLE_TENSORFLOW
 bool FailureDiagnosis::
-diagnoseReverseAutoDiffExpr(ReverseAutoDiffExpr *GE,
-                            bool preservingPrimalResult) {
+diagnoseReverseAutoDiffExpr(ReverseAutoDiffExpr *RADE) {
   // TODO: Sema diagnostics for gradient expressions could be improved by
   // diagnosing non-differentiable arguments/non-differentiable constraints.
-  auto gradType = CS.getType(GE);
+  auto gradType = CS.getType(RADE);
   auto gradFnType = gradType->getAs<AnyFunctionType>();
   assert(gradFnType && "Gradient expression should have function type.");
 

@@ -7486,20 +7483,16 @@ diagnoseReverseAutoDiffExpr(ReverseAutoDiffExpr *GE,
   // If gradient expression has a generic primal, then conversion to the
   // contextual type was not possible.
   if (gradType->hasTypeVariable()) {
-    diagnose(GE->getLoc(), diag::gradient_expr_incompatible_contextual_type,
+    diagnose(RADE->getLoc(), diag::gradient_expr_incompatible_contextual_type,
              contextualType);
     return true;
   }
 
   return false;
 }
 
-bool FailureDiagnosis::visitGradientExpr(GradientExpr *GE) {
-  return diagnoseReverseAutoDiffExpr(GE, /*preservingPrimalResult=*/false);
-}
-
-bool FailureDiagnosis::visitValueAndGradientExpr(ValueAndGradientExpr *GE) {
-  return diagnoseReverseAutoDiffExpr(GE, /*preservingPrimalResult=*/true);
+bool FailureDiagnosis::visitReverseAutoDiffExpr(ReverseAutoDiffExpr *RADE) {
+  return diagnoseReverseAutoDiffExpr(RADE);
 }
 
 bool FailureDiagnosis::visitPoundAssertExpr(PoundAssertExpr *PAE) {

lib/Sema/CSGen.cpp

Lines changed: 4 additions & 0 deletions

@@ -1391,6 +1391,10 @@ namespace {
     Type visitGradientExpr(GradientExpr *GE) {
      return handleReverseAutoDiffExpr(GE, /*preservingOriginalResult=*/false);
     }
+
+    Type visitChainableGradientExpr(ChainableGradientExpr *CGE) {
+      llvm_unreachable("Unhandled");
+    }
 
     Type visitValueAndGradientExpr(ValueAndGradientExpr *VGE) {
       return handleReverseAutoDiffExpr(VGE, /*preservingOriginalResult=*/true);

test/AutoDiff/gradient_expr_parse.swift

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 #gradient(foo, wrt: .0, self) // expected-error {{expected a parameter, which must be }}
 #gradient(foo, wrt: .0, .1) // okay
 
+#chainableGradient(foo, wrt: .0, .1) // okay
+
 #valueAndGradient(foo, wrt: .0, .1) // okay
 
 #adjoint(foo(_:_:)) // okay
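The new test line mirrors the existing #gradient and #valueAndGradient cases. As a hedged illustration of what the accepted form means (foo's body and the seed-last convention are assumptions for illustration):

func foo(_ x: Float, _ y: Float) -> Float { return x * y }

// Differentiate foo with respect to parameters .0 and .1; the returned
// function takes both original arguments plus the seed (assumed trailing).
let pullback = #chainableGradient(foo, wrt: .0, .1)
let (dx, dy) = pullback(3, 4, 1) // assumed result: (4, 3), scaled by the seed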
