Skip to content

Commit 2cadfa0

Browse files
committed
Limit scope of changes only to AutoDiff code
1 parent dd499e9 commit 2cadfa0

File tree

3 files changed

+15
-25
lines changed

3 files changed

+15
-25
lines changed

include/swift/AST/Expr.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4482,8 +4482,6 @@ class ClosureExpr : public AbstractClosureExpr {
44824482
class AutoClosureExpr : public AbstractClosureExpr {
44834483
BraceStmt *Body;
44844484

4485-
ApplyExpr *getUnwrappedCurryThunkImpl() const;
4486-
44874485
public:
44884486
enum class Kind : uint8_t {
44894487
// An autoclosure with type () -> Result. Formed from type checking an
@@ -4545,10 +4543,6 @@ class AutoClosureExpr : public AbstractClosureExpr {
45454543
/// - otherwise, returns nullptr for convenience.
45464544
Expr *getUnwrappedCurryThunkExpr() const;
45474545

4548-
/// Same as getUnwrappedCurryThunkExpr, but get the called ValueDecl instead
4549-
/// of the expr.
4550-
ValueDecl *getUnwrappedCurryThunkCalledValue() const;
4551-
45524546
// Implement isa/cast/dyncast/etc.
45534547
static bool classof(const Expr *E) {
45544548
return E->getKind() == ExprKind::AutoClosure;

lib/AST/Expr.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,7 +2145,7 @@ Expr *AutoClosureExpr::getSingleExpressionBody() const {
21452145
return cast<ReturnStmt>(Body->getLastElement().get<Stmt *>())->getResult();
21462146
}
21472147

2148-
ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
2148+
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
21492149
auto maybeUnwrapOpenExistential = [](Expr *expr) {
21502150
if (auto *openExistential = dyn_cast<OpenExistentialExpr>(expr)) {
21512151
expr = openExistential->getSubExpr()->getSemanticsProvidingExpr();
@@ -2189,7 +2189,7 @@ ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
21892189
body = maybeUnwrapConversions(body);
21902190

21912191
if (auto *outerCall = dyn_cast<ApplyExpr>(body)) {
2192-
return outerCall;
2192+
return outerCall->getFn();
21932193
}
21942194

21952195
assert(false && "Malformed curry thunk?");
@@ -2210,7 +2210,7 @@ ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
22102210
if (auto *outerCall = dyn_cast<ApplyExpr>(innerBody)) {
22112211
auto outerFn = maybeUnwrapConversions(outerCall->getFn());
22122212
if (auto *innerCall = dyn_cast<ApplyExpr>(outerFn)) {
2213-
return innerCall;
2213+
return innerCall->getFn();
22142214
}
22152215
}
22162216
}
@@ -2223,20 +2223,6 @@ ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
22232223
return nullptr;
22242224
}
22252225

2226-
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
2227-
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2228-
if (ae == nullptr)
2229-
return nullptr;
2230-
return ae->getFn();
2231-
}
2232-
2233-
ValueDecl *AutoClosureExpr::getUnwrappedCurryThunkCalledValue() const {
2234-
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2235-
if (ae == nullptr)
2236-
return nullptr;
2237-
return ae->getCalledValue(/*skipFunctionConversions=*/true);
2238-
}
2239-
22402226
FORWARD_SOURCE_LOCS_TO(UnresolvedPatternExpr, subPattern)
22412227

22422228
TypeExpr::TypeExpr(TypeRepr *Repr)

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,18 @@ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
485485
if (autoCE == nullptr)
486486
return nullptr;
487487

488-
auto *afd =
489-
cast<AbstractFunctionDecl>(autoCE->getUnwrappedCurryThunkCalledValue());
488+
AbstractFunctionDecl *afd = nullptr;
489+
Expr *expr = autoCE->getUnwrappedCurryThunkExpr();
490+
switch (autoCE->getThunkKind()) {
491+
default:
492+
llvm_unreachable("Only single and double curry thunks are expected");
493+
case AutoClosureExpr::Kind::SingleCurryThunk:
494+
case AutoClosureExpr::Kind::DoubleCurryThunk:
495+
afd = cast<AbstractFunctionDecl>(cast<ApplyExpr>(expr)->getCalledValue(
496+
/*skipFunctionConversions=*/true));
497+
break;
498+
}
499+
assert(afd);
490500

491501
auto silFnIt = afdToSILFn.find(afd);
492502
if (silFnIt == afdToSILFn.end()) {

0 commit comments

Comments
 (0)