Skip to content

Commit 17dd2d0

Browse files
authored
Merge pull request #33408 from DougGregor/async-closures
[Concurrency] Add support for 'async' closures.
2 parents 893c6b2 + 5dd1bfe commit 17dd2d0

File tree

10 files changed

+150
-43
lines changed

10 files changed

+150
-43
lines changed

include/swift/AST/Expr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3815,6 +3815,9 @@ class ClosureExpr : public AbstractClosureExpr {
38153815
/// this information directly on the ClosureExpr.
38163816
VarDecl * CapturedSelfDecl;
38173817

3818+
/// The location of the "async", if present.
3819+
SourceLoc AsyncLoc;
3820+
38183821
/// The location of the "throws", if present.
38193822
SourceLoc ThrowsLoc;
38203823

@@ -3833,14 +3836,15 @@ class ClosureExpr : public AbstractClosureExpr {
38333836
llvm::PointerIntPair<BraceStmt *, 1, bool> Body;
38343837
public:
38353838
ClosureExpr(SourceRange bracketRange, VarDecl *capturedSelfDecl,
3836-
ParameterList *params, SourceLoc throwsLoc, SourceLoc arrowLoc,
3837-
SourceLoc inLoc, TypeExpr *explicitResultType,
3839+
ParameterList *params, SourceLoc asyncLoc, SourceLoc throwsLoc,
3840+
SourceLoc arrowLoc, SourceLoc inLoc, TypeExpr *explicitResultType,
38383841
unsigned discriminator, DeclContext *parent)
38393842
: AbstractClosureExpr(ExprKind::Closure, Type(), /*Implicit=*/false,
38403843
discriminator, parent),
38413844
BracketRange(bracketRange),
38423845
CapturedSelfDecl(capturedSelfDecl),
3843-
ThrowsLoc(throwsLoc), ArrowLoc(arrowLoc), InLoc(inLoc),
3846+
AsyncLoc(asyncLoc), ThrowsLoc(throwsLoc), ArrowLoc(arrowLoc),
3847+
InLoc(inLoc),
38443848
ExplicitResultTypeAndBodyState(explicitResultType, BodyState::Parsed),
38453849
Body(nullptr) {
38463850
setParameterList(params);
@@ -3888,7 +3892,12 @@ class ClosureExpr : public AbstractClosureExpr {
38883892
SourceLoc getInLoc() const {
38893893
return InLoc;
38903894
}
3891-
3895+
3896+
/// Retrieve the location of the 'async' for a closure that has it.
3897+
SourceLoc getAsyncLoc() const {
3898+
return AsyncLoc;
3899+
}
3900+
38923901
/// Retrieve the location of the 'throws' for a closure that has it.
38933902
SourceLoc getThrowsLoc() const {
38943903
return ThrowsLoc;

include/swift/AST/ExtInfo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,14 @@ class ASTExtInfo {
436436
return builder.withThrows(throws).build();
437437
}
438438

439+
/// Helper method for changing only the async field.
440+
///
441+
/// Prefer using \c ASTExtInfoBuilder::withAsync for chaining.
442+
LLVM_NODISCARD
443+
ASTExtInfo withAsync(bool async = true) const {
444+
return builder.withAsync(async).build();
445+
}
446+
439447
bool isEqualTo(ASTExtInfo other, bool useClangTypes) const {
440448
return builder.isEqualTo(other.builder, useClangTypes);
441449
}

include/swift/Parse/Parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,7 @@ class Parser {
15721572
SmallVectorImpl<CaptureListEntry> &captureList,
15731573
VarDecl *&capturedSelfParamDecl,
15741574
ParameterList *&params,
1575+
SourceLoc &asyncLoc,
15751576
SourceLoc &throwsLoc,
15761577
SourceLoc &arrowLoc,
15771578
TypeExpr *&explicitResultType,

lib/Parse/ParseExpr.cpp

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,7 +2432,8 @@ bool Parser::
24322432
parseClosureSignatureIfPresent(SourceRange &bracketRange,
24332433
SmallVectorImpl<CaptureListEntry> &captureList,
24342434
VarDecl *&capturedSelfDecl,
2435-
ParameterList *&params, SourceLoc &throwsLoc,
2435+
ParameterList *&params,
2436+
SourceLoc &asyncLoc, SourceLoc &throwsLoc,
24362437
SourceLoc &arrowLoc,
24372438
TypeExpr *&explicitResultType, SourceLoc &inLoc){
24382439
// Clear out result parameters.
@@ -2444,6 +2445,24 @@ parseClosureSignatureIfPresent(SourceRange &bracketRange,
24442445
explicitResultType = nullptr;
24452446
inLoc = SourceLoc();
24462447

2448+
// Consume 'async', 'throws', and 'rethrows', but in any order.
2449+
auto consumeAsyncThrows = [&] {
2450+
bool hadAsync = false;
2451+
if (Context.LangOpts.EnableExperimentalConcurrency &&
2452+
Tok.isContextualKeyword("async")) {
2453+
consumeToken();
2454+
hadAsync = true;
2455+
}
2456+
2457+
if (!consumeIf(tok::kw_throws) && !consumeIf(tok::kw_rethrows))
2458+
return;
2459+
2460+
if (Context.LangOpts.EnableExperimentalConcurrency && !hadAsync &&
2461+
Tok.isContextualKeyword("async")) {
2462+
consumeToken();
2463+
}
2464+
};
2465+
24472466
// If we have a leading token that may be part of the closure signature, do a
24482467
// speculative parse to validate it and look for 'in'.
24492468
if (Tok.isAny(tok::l_paren, tok::l_square, tok::identifier, tok::kw__)) {
@@ -2465,7 +2484,8 @@ parseClosureSignatureIfPresent(SourceRange &bracketRange,
24652484

24662485
// Consume the ')', if it's there.
24672486
if (consumeIf(tok::r_paren)) {
2468-
consumeIf(tok::kw_throws) || consumeIf(tok::kw_rethrows);
2487+
consumeAsyncThrows();
2488+
24692489
// Parse the func-signature-result, if present.
24702490
if (consumeIf(tok::arrow)) {
24712491
if (!canParseType())
@@ -2485,8 +2505,8 @@ parseClosureSignatureIfPresent(SourceRange &bracketRange,
24852505

24862506
return false;
24872507
}
2488-
2489-
consumeIf(tok::kw_throws) || consumeIf(tok::kw_rethrows);
2508+
2509+
consumeAsyncThrows();
24902510

24912511
// Parse the func-signature-result, if present.
24922512
if (consumeIf(tok::arrow)) {
@@ -2682,11 +2702,10 @@ parseClosureSignatureIfPresent(SourceRange &bracketRange,
26822702

26832703
params = ParameterList::create(Context, elements);
26842704
}
2685-
2686-
if (Tok.is(tok::kw_throws)) {
2687-
throwsLoc = consumeToken();
2688-
} else if (Tok.is(tok::kw_rethrows)) {
2689-
throwsLoc = consumeToken();
2705+
2706+
bool rethrows = false;
2707+
parseAsyncThrows(SourceLoc(), asyncLoc, throwsLoc, &rethrows);
2708+
if (rethrows) {
26902709
diagnose(throwsLoc, diag::rethrowing_function_type)
26912710
.fixItReplace(throwsLoc, "throws");
26922711
}
@@ -2803,13 +2822,14 @@ ParserResult<Expr> Parser::parseExprClosure() {
28032822
SmallVector<CaptureListEntry, 2> captureList;
28042823
VarDecl *capturedSelfDecl;
28052824
ParameterList *params = nullptr;
2825+
SourceLoc asyncLoc;
28062826
SourceLoc throwsLoc;
28072827
SourceLoc arrowLoc;
28082828
TypeExpr *explicitResultType;
28092829
SourceLoc inLoc;
2810-
parseClosureSignatureIfPresent(bracketRange, captureList,
2811-
capturedSelfDecl, params, throwsLoc,
2812-
arrowLoc, explicitResultType, inLoc);
2830+
parseClosureSignatureIfPresent(
2831+
bracketRange, captureList, capturedSelfDecl, params, asyncLoc, throwsLoc,
2832+
arrowLoc, explicitResultType, inLoc);
28132833

28142834
// If the closure was created in the context of an array type signature's
28152835
// size expression, there will not be a local context. A parse error will
@@ -2824,10 +2844,9 @@ ParserResult<Expr> Parser::parseExprClosure() {
28242844
unsigned discriminator = CurLocalContext->claimNextClosureDiscriminator();
28252845

28262846
// Create the closure expression and enter its context.
2827-
auto *closure = new (Context) ClosureExpr(bracketRange, capturedSelfDecl,
2828-
params, throwsLoc, arrowLoc, inLoc,
2829-
explicitResultType, discriminator,
2830-
CurDeclContext);
2847+
auto *closure = new (Context) ClosureExpr(
2848+
bracketRange, capturedSelfDecl, params, asyncLoc, throwsLoc, arrowLoc,
2849+
inLoc, explicitResultType, discriminator, CurDeclContext);
28312850
// The arguments to the func are defined in their own scope.
28322851
Scope S(this, ScopeKind::ClosureParams);
28332852
ParseFunctionBody cc(*this, closure);

lib/Sema/CSGen.cpp

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,9 +2114,7 @@ namespace {
21142114
}
21152115
}
21162116

2117-
auto extInfo = FunctionType::ExtInfo();
2118-
if (closureCanThrow(closure))
2119-
extInfo = extInfo.withThrows();
2117+
auto extInfo = closureEffects(closure);
21202118

21212119
// Closure expressions always have function type. In cases where a
21222120
// parameter or return type is omitted, a fresh type variable is used to
@@ -2596,9 +2594,12 @@ namespace {
25962594
return CS.getType(expr->getClosureBody());
25972595
}
25982596

2599-
/// Walk a closure AST to determine if it can throw.
2600-
bool closureCanThrow(ClosureExpr *expr) {
2601-
// A walker that looks for 'try' or 'throw' expressions
2597+
/// Walk a closure AST to determine its effects.
2598+
///
2599+
/// \returns a function's extended info describing the effects, as
2600+
/// determined syntactically.
2601+
FunctionType::ExtInfo closureEffects(ClosureExpr *expr) {
2602+
// A walker that looks for 'try' and 'throw' expressions
26022603
// that aren't nested within closures, nested declarations,
26032604
// or exhaustive catches.
26042605
class FindInnerThrows : public ASTWalker {
@@ -2743,18 +2744,62 @@ namespace {
27432744

27442745
bool foundThrow() { return FoundThrow; }
27452746
};
2746-
2747-
if (expr->getThrowsLoc().isValid())
2748-
return true;
2749-
2747+
2748+
// A walker that looks for 'async' and 'await' expressions
2749+
// that aren't nested within closures or nested declarations.
2750+
class FindInnerAsync : public ASTWalker {
2751+
bool FoundAsync = false;
2752+
2753+
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
2754+
// If we've found an 'await', record it and terminate the traversal.
2755+
if (isa<AwaitExpr>(expr)) {
2756+
FoundAsync = true;
2757+
return { false, nullptr };
2758+
}
2759+
2760+
// Do not recurse into other closures.
2761+
if (isa<ClosureExpr>(expr))
2762+
return { false, expr };
2763+
2764+
return { true, expr };
2765+
}
2766+
2767+
bool walkToDeclPre(Decl *decl) override {
2768+
// Do not walk into function or type declarations.
2769+
if (!isa<PatternBindingDecl>(decl))
2770+
return false;
2771+
2772+
return true;
2773+
}
2774+
2775+
public:
2776+
bool foundAsync() { return FoundAsync; }
2777+
};
2778+
2779+
// If either 'throws' or 'async' was explicitly specified, use that
2780+
// set of effects.
2781+
bool throws = expr->getThrowsLoc().isValid();
2782+
bool async = expr->getAsyncLoc().isValid();
2783+
if (throws || async) {
2784+
return ASTExtInfoBuilder()
2785+
.withThrows(throws)
2786+
.withAsync(async)
2787+
.build();
2788+
}
2789+
2790+
// Scan the body to determine the effects.
27502791
auto body = expr->getBody();
2751-
27522792
if (!body)
2753-
return false;
2754-
2755-
auto tryFinder = FindInnerThrows(CS, expr);
2756-
body->walk(tryFinder);
2757-
return tryFinder.foundThrow();
2793+
return FunctionType::ExtInfo();
2794+
2795+
auto throwFinder = FindInnerThrows(CS, expr);
2796+
body->walk(throwFinder);
2797+
auto asyncFinder = FindInnerAsync();
2798+
body->walk(asyncFinder);
2799+
return ASTExtInfoBuilder()
2800+
.withThrows(throwFinder.foundThrow())
2801+
.withAsync(asyncFinder.foundAsync())
2802+
.build();
27582803
}
27592804

27602805
Type visitClosureExpr(ClosureExpr *closure) {

lib/Sema/DebuggerTestingTransform.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,8 @@ class DebuggerTestingTransform : public ASTWalker {
227227
auto *Params = ParameterList::createEmpty(Ctx);
228228
auto *Closure = new (Ctx)
229229
ClosureExpr(SourceRange(), nullptr, Params, SourceLoc(), SourceLoc(),
230-
SourceLoc(), nullptr, DF.getNextDiscriminator(),
231-
getCurrentDeclContext());
230+
SourceLoc(), SourceLoc(), nullptr,
231+
DF.getNextDiscriminator(), getCurrentDeclContext());
232232
Closure->setImplicit(true);
233233

234234
// TODO: Save and return the value of $OriginalExpr.

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,9 @@ deriveBodyDifferentiable_zeroTangentVectorInitializer(
393393

394394
auto *closureParams = ParameterList::createEmpty(C);
395395
auto *closure = new (C) ClosureExpr(
396-
SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(),
397-
SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C),
398-
discriminator, funcDecl);
396+
SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams,
397+
SourceLoc(), SourceLoc(), SourceLoc(), SourceLoc(),
398+
TypeExpr::createImplicit(resultTy, C), discriminator, funcDecl);
399399
closure->setImplicit();
400400
auto *closureReturn = new (C) ReturnStmt(SourceLoc(), zeroExpr, true);
401401
auto *closureBody =
@@ -504,8 +504,8 @@ deriveBodyDifferentiable_zeroTangentVectorInitializer(
504504
auto *closureParams = ParameterList::createEmpty(C);
505505
auto *closure = new (C) ClosureExpr(
506506
SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(),
507-
SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C),
508-
discriminator, funcDecl);
507+
SourceLoc(), SourceLoc(), SourceLoc(),
508+
TypeExpr::createImplicit(resultTy, C), discriminator, funcDecl);
509509
closure->setImplicit();
510510
auto *closureReturn = new (C) ReturnStmt(SourceLoc(), callExpr, true);
511511
auto *closureBody =

test/Parse/async-syntax.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,9 @@ func testTypeExprs() {
1414
func testAwaitOperator() async {
1515
let _ = __await asyncGlobal1()
1616
}
17+
18+
func testAsyncClosure() {
19+
let _ = { () async in 5 }
20+
let _ = { () throws in 5 }
21+
let _ = { () async throws in 5 }
22+
}

test/expr/unary/async_await.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,19 @@ func testAutoclosure() async {
4848
__await acceptAutoclosureNonAsync(getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
4949
// expected-warning@-1{{no calls to 'async' functions occur within 'await' expression}}
5050
}
51+
52+
// Test inference of 'async' from the body of a closure.
53+
func testClosure() {
54+
let closure = {
55+
__await getInt()
56+
}
57+
58+
let _: () -> Int = closure // expected-error{{cannot convert value of type '() async -> Int' to specified type '() -> Int'}}
59+
60+
let closure2 = { () async -> Int in
61+
print("here")
62+
return __await getInt()
63+
}
64+
65+
let _: () -> Int = closure2 // expected-error{{cannot convert value of type '() async -> Int' to specified type '() -> Int'}}
66+
}

utils/gyb_syntax_support/ExprNodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@
390390
Child('SimpleInput', kind='ClosureParamList'),
391391
Child('Input', kind='ParameterClause'),
392392
]),
393+
Child('AsyncKeyword', kind='IdentifierToken',
394+
classification='Keyword',
395+
text_choices=['async'], is_optional=True),
393396
Child('ThrowsTok', kind='ThrowsToken', is_optional=True),
394397
Child('Output', kind='ReturnClause', is_optional=True),
395398
Child('InTok', kind='InToken'),

0 commit comments

Comments
 (0)