Skip to content

Commit 53992d0

Browse files
committed
[ConstraintSystem] NFC: Extract async node identification into a function
It's useful not only for effect determination but for diagnostics as well.
1 parent c472ce4 commit 53992d0

File tree

2 files changed

+66
-52
lines changed

2 files changed

+66
-52
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5592,6 +5592,11 @@ Type getConcreteReplacementForProtocolSelfType(ValueDecl *member);
55925592
/// of operator overload choices.
55935593
bool isOperatorDisjunction(Constraint *disjunction);
55945594

5595+
/// Find out whether closure body has any `async` or `await` expressions,
5596+
/// declarations, or statements directly in its body (no in other closures
5597+
/// or nested declarations).
5598+
ASTNode findAsyncNode(ClosureExpr *closure);
5599+
55955600
} // end namespace constraints
55965601

55975602
template<typename ...Args>

lib/Sema/ConstraintSystem.cpp

Lines changed: 61 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,52 +2486,6 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
24862486
bool foundThrow() { return FoundThrow; }
24872487
};
24882488

2489-
// A walker that looks for 'async' and 'await' expressions
2490-
// that aren't nested within closures or nested declarations.
2491-
class FindInnerAsync : public ASTWalker {
2492-
bool FoundAsync = false;
2493-
2494-
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
2495-
// If we've found an 'await', record it and terminate the traversal.
2496-
if (isa<AwaitExpr>(expr)) {
2497-
FoundAsync = true;
2498-
return { false, nullptr };
2499-
}
2500-
2501-
// Do not recurse into other closures.
2502-
if (isa<ClosureExpr>(expr))
2503-
return { false, expr };
2504-
2505-
return { true, expr };
2506-
}
2507-
2508-
bool walkToDeclPre(Decl *decl) override {
2509-
// Do not walk into function or type declarations.
2510-
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
2511-
if (patternBinding->isSpawnLet())
2512-
FoundAsync = true;
2513-
2514-
return true;
2515-
}
2516-
2517-
return false;
2518-
}
2519-
2520-
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
2521-
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
2522-
if (forEach->getAwaitLoc().isValid()) {
2523-
FoundAsync = true;
2524-
return { false, nullptr };
2525-
}
2526-
}
2527-
2528-
return { true, stmt };
2529-
}
2530-
2531-
public:
2532-
bool foundAsync() { return FoundAsync; }
2533-
};
2534-
25352489
// If either 'throws' or 'async' was explicitly specified, use that
25362490
// set of effects.
25372491
bool throws = expr->getThrowsLoc().isValid();
@@ -2552,13 +2506,11 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
25522506

25532507
auto throwFinder = FindInnerThrows(*this, expr);
25542508
body->walk(throwFinder);
2555-
auto asyncFinder = FindInnerAsync();
2556-
body->walk(asyncFinder);
25572509
auto result = ASTExtInfoBuilder()
2558-
.withThrows(throwFinder.foundThrow())
2559-
.withAsync(asyncFinder.foundAsync())
2560-
.withConcurrent(concurrent)
2561-
.build();
2510+
.withThrows(throwFinder.foundThrow())
2511+
.withAsync(bool(findAsyncNode(expr)))
2512+
.withConcurrent(concurrent)
2513+
.build();
25622514
closureEffectsCache[expr] = result;
25632515
return result;
25642516
}
@@ -5708,3 +5660,60 @@ bool constraints::isOperatorDisjunction(Constraint *disjunction) {
57085660
auto *decl = getOverloadChoiceDecl(choices.front());
57095661
return decl ? decl->isOperator() : false;
57105662
}
5663+
5664+
ASTNode constraints::findAsyncNode(ClosureExpr *closure) {
5665+
auto *body = closure->getBody();
5666+
if (!body)
5667+
return ASTNode();
5668+
5669+
// A walker that looks for 'async' and 'await' expressions
5670+
// that aren't nested within closures or nested declarations.
5671+
class FindInnerAsync : public ASTWalker {
5672+
ASTNode AsyncNode;
5673+
5674+
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
5675+
// If we've found an 'await', record it and terminate the traversal.
5676+
if (isa<AwaitExpr>(expr)) {
5677+
AsyncNode = expr;
5678+
return {false, nullptr};
5679+
}
5680+
5681+
// Do not recurse into other closures.
5682+
if (isa<ClosureExpr>(expr))
5683+
return {false, expr};
5684+
5685+
return {true, expr};
5686+
}
5687+
5688+
bool walkToDeclPre(Decl *decl) override {
5689+
// Do not walk into function or type declarations.
5690+
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
5691+
if (patternBinding->isSpawnLet())
5692+
AsyncNode = patternBinding;
5693+
5694+
return true;
5695+
}
5696+
5697+
return false;
5698+
}
5699+
5700+
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
5701+
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
5702+
if (forEach->getAwaitLoc().isValid()) {
5703+
AsyncNode = forEach;
5704+
return {false, nullptr};
5705+
}
5706+
}
5707+
5708+
return {true, stmt};
5709+
}
5710+
5711+
public:
5712+
ASTNode getAsyncNode() { return AsyncNode; }
5713+
};
5714+
5715+
FindInnerAsync asyncFinder;
5716+
body->walk(asyncFinder);
5717+
5718+
return asyncFinder.getAsyncNode();
5719+
}

0 commit comments

Comments
 (0)