Skip to content

Commit 40c12cc

Browse files
committed
[Concurrency] Implement restrictions on calls to 'async' functions.
Implement missing restrictions on calls to 'async': * Diagnose async calls/uses of await in illegal contexts (such as default arguments) * Diagnose async calls/uses of await in functions/closures that are not asynchronous themselves * Handle autoclosure arguments as their own separate contexts (so 'await' has to go on the argument), which differs from error handling (where the 'try' can go outside) because we want to be more particular about marking the specific suspension points.
1 parent 6d8a50d commit 40c12cc

File tree

3 files changed

+153
-22
lines changed

3 files changed

+153
-22
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4038,8 +4038,24 @@ NOTE(note_disable_error_propagation,none,
40384038
"did you mean to disable error propagation?", ())
40394039
ERROR(async_call_without_await,none,
40404040
"call is 'async' but is not marked with 'await'", ())
4041+
ERROR(async_call_without_await_in_autoclosure,none,
4042+
"call is 'async' in an autoclosure argument is not marked with 'await'", ())
40414043
WARNING(no_async_in_await,none,
40424044
"no calls to 'async' functions occur within 'await' expression", ())
4045+
ERROR(async_call_in_illegal_context,none,
4046+
"'async' call cannot occur in "
4047+
"%select{<<ERROR>>|a default argument|a property initializer|a global variable initializer|an enum case raw value|a catch pattern|a catch guard expression|a defer body}0",
4048+
(unsigned))
4049+
ERROR(await_in_illegal_context,none,
4050+
"'await' operation cannot occur in "
4051+
"%select{<<ERROR>>|a default argument|a property initializer|a global variable initializer|an enum case raw value|a catch pattern|a catch guard expression|a defer body}0",
4052+
(unsigned))
4053+
ERROR(async_in_nonasync_function,none,
4054+
"%select{'async'|'await'}0 in %select{a function|an autoclosure}1 that "
4055+
"does not support concurrency",
4056+
(bool, bool))
4057+
NOTE(note_add_async_to_function,none,
4058+
"add 'async' to function %0 to make it asynchronous", (DeclName))
40434059

40444060
WARNING(no_throw_in_try,none,
40454061
"no calls to throwing functions occur within 'try' expression", ())

lib/Sema/TypeCheckEffects.cpp

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ class Context {
857857
Kind TheKind;
858858
Optional<AnyFunctionRef> Function;
859859
bool HandlesErrors = false;
860+
bool HandlesAsync = false;
860861

861862
/// Whether error-handling queries should ignore the function context, e.g.,
862863
/// for autoclosure and rethrows checks.
@@ -870,9 +871,10 @@ class Context {
870871
assert(TheKind != Kind::PotentiallyHandled);
871872
}
872873

873-
explicit Context(bool handlesErrors, Optional<AnyFunctionRef> function)
874+
explicit Context(bool handlesErrors, bool handlesAsync,
875+
Optional<AnyFunctionRef> function)
874876
: TheKind(Kind::PotentiallyHandled), Function(function),
875-
HandlesErrors(handlesErrors) { }
877+
HandlesErrors(handlesErrors), HandlesAsync(handlesAsync) { }
876878

877879
public:
878880
/// Whether this is a function that rethrows.
@@ -910,7 +912,7 @@ class Context {
910912

911913
static Context forTopLevelCode(TopLevelCodeDecl *D) {
912914
// Top-level code implicitly handles errors and 'async' calls.
913-
return Context(/*handlesErrors=*/true, None);
915+
return Context(/*handlesErrors=*/true, /*handlesAsync=*/true, None);
914916
}
915917

916918
static Context forFunction(AbstractFunctionDecl *D) {
@@ -930,8 +932,7 @@ class Context {
930932
}
931933
}
932934

933-
bool handlesErrors = D->hasThrows();
934-
return Context(handlesErrors, AnyFunctionRef(D));
935+
return Context(D->hasThrows(), D->hasAsync(), AnyFunctionRef(D));
935936
}
936937

937938
static Context forDeferBody() {
@@ -956,12 +957,15 @@ class Context {
956957
static Context forClosure(AbstractClosureExpr *E) {
957958
// Determine whether the closure has throwing function type.
958959
bool closureTypeThrows = true;
960+
bool closureTypeIsAsync = true;
959961
if (auto closureType = E->getType()) {
960-
if (auto fnType = closureType->getAs<AnyFunctionType>())
962+
if (auto fnType = closureType->getAs<AnyFunctionType>()) {
961963
closureTypeThrows = fnType->isThrowing();
964+
closureTypeIsAsync = fnType->isAsync();
965+
}
962966
}
963967

964-
return Context(closureTypeThrows, AnyFunctionRef(E));
968+
return Context(closureTypeThrows, closureTypeIsAsync, AnyFunctionRef(E));
965969
}
966970

967971
static Context forCatchPattern(CaseStmt *S) {
@@ -1013,6 +1017,10 @@ class Context {
10131017
llvm_unreachable("bad error kind");
10141018
}
10151019

1020+
bool handlesAsync() const {
1021+
return HandlesAsync;
1022+
}
1023+
10161024
DeclContext *getRethrowsDC() const {
10171025
if (!isRethrows())
10181026
return nullptr;
@@ -1182,7 +1190,6 @@ class Context {
11821190
case Kind::DeferBody:
11831191
diagnoseThrowInIllegalContext(Diags, E, getKind());
11841192
return;
1185-
11861193
}
11871194
llvm_unreachable("bad context kind");
11881195
}
@@ -1211,6 +1218,64 @@ class Context {
12111218
}
12121219
llvm_unreachable("bad context kind");
12131220
}
1221+
1222+
void diagnoseUncoveredAsyncSite(ASTContext &ctx, ASTNode node) {
1223+
SourceRange highlight;
1224+
1225+
// Generate more specific messages in some cases.
1226+
if (auto apply = dyn_cast_or_null<ApplyExpr>(node.dyn_cast<Expr*>()))
1227+
highlight = apply->getSourceRange();
1228+
1229+
auto diag = diag::async_call_without_await;
1230+
if (isAutoClosure())
1231+
diag = diag::async_call_without_await_in_autoclosure;
1232+
ctx.Diags.diagnose(node.getStartLoc(), diag)
1233+
.highlight(highlight);
1234+
}
1235+
1236+
void diagnoseAsyncInIllegalContext(DiagnosticEngine &Diags, ASTNode node) {
1237+
if (auto *e = node.dyn_cast<Expr*>()) {
1238+
if (isa<ApplyExpr>(e)) {
1239+
Diags.diagnose(e->getLoc(), diag::async_call_in_illegal_context,
1240+
static_cast<unsigned>(getKind()));
1241+
return;
1242+
}
1243+
}
1244+
1245+
Diags.diagnose(node.getStartLoc(), diag::await_in_illegal_context,
1246+
static_cast<unsigned>(getKind()));
1247+
}
1248+
1249+
void maybeAddAsyncNote(DiagnosticEngine &Diags) {
1250+
if (!Function)
1251+
return;
1252+
1253+
auto func = dyn_cast_or_null<FuncDecl>(Function->getAbstractFunctionDecl());
1254+
if (!func)
1255+
return;
1256+
1257+
func->diagnose(diag::note_add_async_to_function, func->getName());
1258+
}
1259+
1260+
void diagnoseUnhandledAsyncSite(DiagnosticEngine &Diags, ASTNode node) {
1261+
switch (getKind()) {
1262+
case Kind::PotentiallyHandled:
1263+
Diags.diagnose(node.getStartLoc(), diag::async_in_nonasync_function,
1264+
node.isExpr(ExprKind::Await), isAutoClosure());
1265+
maybeAddAsyncNote(Diags);
1266+
return;
1267+
1268+
case Kind::EnumElementInitializer:
1269+
case Kind::GlobalVarInitializer:
1270+
case Kind::IVarInitializer:
1271+
case Kind::DefaultArgument:
1272+
case Kind::CatchPattern:
1273+
case Kind::CatchGuard:
1274+
case Kind::DeferBody:
1275+
diagnoseAsyncInIllegalContext(Diags, node);
1276+
return;
1277+
}
1278+
}
12141279
};
12151280

12161281
/// A class to walk over a local context and validate the correctness
@@ -1322,6 +1387,12 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
13221387
Self.MaxThrowingKind = ThrowingKind::None;
13231388
}
13241389

1390+
void resetCoverageForAutoclosureBody() {
1391+
Self.Flags.clear(ContextFlags::IsAsyncCovered);
1392+
Self.Flags.clear(ContextFlags::HasAnyAsyncSite);
1393+
Self.Flags.clear(ContextFlags::HasAnyAwait);
1394+
}
1395+
13251396
void resetCoverageForDoCatch() {
13261397
Self.Flags.reset();
13271398
Self.MaxThrowingKind = ThrowingKind::None;
@@ -1409,6 +1480,7 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
14091480
ShouldRecurse_t checkAutoClosure(AutoClosureExpr *E) {
14101481
ContextScope scope(*this, Context::forClosure(E));
14111482
scope.enterSubFunction();
1483+
scope.resetCoverageForAutoclosureBody();
14121484
E->getBody()->walk(*this);
14131485
scope.preserveCoverageFromAutoclosureBody();
14141486
return ShouldNotRecurse;
@@ -1572,17 +1644,14 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
15721644
if (classification.isAsync()) {
15731645
// Remember that we've seen an async call.
15741646
Flags.set(ContextFlags::HasAnyAsyncSite);
1575-
1647+
1648+
// Diagnose async calls in a context that doesn't handle async.
1649+
if (!CurContext.handlesAsync()) {
1650+
CurContext.diagnoseUnhandledAsyncSite(Ctx.Diags, E);
1651+
}
15761652
// Diagnose async calls that are outside of an await context.
1577-
if (!Flags.has(ContextFlags::IsAsyncCovered)) {
1578-
SourceRange highlight;
1579-
1580-
// Generate more specific messages in some cases.
1581-
if (auto e = dyn_cast_or_null<ApplyExpr>(E.dyn_cast<Expr*>()))
1582-
highlight = e->getSourceRange();
1583-
1584-
Ctx.Diags.diagnose(E.getStartLoc(), diag::async_call_without_await)
1585-
.highlight(highlight);
1653+
else if (!Flags.has(ContextFlags::IsAsyncCovered)) {
1654+
CurContext.diagnoseUncoveredAsyncSite(Ctx, E);
15861655
}
15871656
}
15881657

@@ -1626,10 +1695,16 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
16261695
scope.enterAwait();
16271696

16281697
E->getSubExpr()->walk(*this);
1629-
1630-
// Warn about 'await' expressions that weren't actually needed.
1631-
if (!Flags.has(ContextFlags::HasAnyAsyncSite))
1632-
Ctx.Diags.diagnose(E->getAwaitLoc(), diag::no_async_in_await);
1698+
1699+
// Warn about 'await' expressions that weren't actually needed, unless of
1700+
// course we're in a context that could never handle an 'async'. Then, we
1701+
// produce an error.
1702+
if (!Flags.has(ContextFlags::HasAnyAsyncSite)) {
1703+
if (CurContext.handlesAsync())
1704+
Ctx.Diags.diagnose(E->getAwaitLoc(), diag::no_async_in_await);
1705+
else
1706+
CurContext.diagnoseUnhandledAsyncSite(Ctx.Diags, E);
1707+
}
16331708

16341709
// Inform the parent of the walk that an 'await' exists here.
16351710
scope.preserveCoverageFromAwaitOperand();

test/expr/unary/async_await.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,43 @@ func test1(asyncfp : () async -> Int, fp : () -> Int) async {
88
_ = asyncfp() // expected-error {{call is 'async' but is not marked with 'await'}}
99
}
1010

11+
func getInt() async -> Int { return 5 }
12+
13+
// Locations where "await" is prohibited.
14+
func test2(
15+
defaulted: Int = __await getInt() // expected-error{{'async' call cannot occur in a default argument}}
16+
) async {
17+
defer {
18+
_ = __await getInt() // expected-error{{'async' call cannot occur in a defer body}}
19+
}
20+
print("foo")
21+
}
22+
23+
func test3() { // expected-note{{add 'async' to function 'test3()' to make it asynchronous}}
24+
_ = __await getInt() // expected-error{{'async' in a function that does not support concurrency}}
25+
}
26+
27+
enum SomeEnum: Int {
28+
case foo = __await 5 // expected-error{{raw value for enum case must be a literal}}
29+
}
30+
31+
struct SomeStruct {
32+
var x = __await getInt() // expected-error{{'async' call cannot occur in a property initializer}}
33+
static var y = __await getInt() // expected-error{{'async' call cannot occur in a global variable initializer}}
34+
}
35+
36+
func acceptAutoclosureNonAsync(_: @autoclosure () -> Int) { }
37+
func acceptAutoclosureAsync(_: @autoclosure () async -> Int) { }
38+
39+
func testAutoclosure() async {
40+
acceptAutoclosureAsync(getInt()) // expected-error{{call is 'async' in an autoclosure argument is not marked with 'await'}}
41+
acceptAutoclosureNonAsync(getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
42+
43+
acceptAutoclosureAsync(__await getInt())
44+
acceptAutoclosureNonAsync(__await getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
45+
46+
__await acceptAutoclosureAsync(getInt()) // expected-error{{call is 'async' in an autoclosure argument is not marked with 'await'}}
47+
// expected-warning@-1{{no calls to 'async' functions occur within 'await' expression}}
48+
__await acceptAutoclosureNonAsync(getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
49+
// expected-warning@-1{{no calls to 'async' functions occur within 'await' expression}}
50+
}

0 commit comments

Comments
 (0)