Skip to content

Commit faec07e

Browse files
authored
Merge pull request #35761 from phausler/async_sequence_conformance_fixes
[Sema] Corrections for for-await-in syntax to prevent specific bad code-gen scenarios and improve diagnostics
2 parents 5013d20 + 8d9d099 commit faec07e

File tree

4 files changed

+93
-25
lines changed

4 files changed

+93
-25
lines changed

lib/Sema/ConstraintSystem.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,6 +2295,13 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
22952295
return { false, stmt };
22962296
}
22972297

2298+
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
2299+
if (forEach->getTryLoc().isValid()) {
2300+
FoundThrow = true;
2301+
return { false, nullptr };
2302+
}
2303+
}
2304+
22982305
return { true, stmt };
22992306
}
23002307

@@ -2332,6 +2339,17 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
23322339
return true;
23332340
}
23342341

2342+
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
2343+
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
2344+
if (forEach->getAwaitLoc().isValid()) {
2345+
FoundAsync = true;
2346+
return { false, nullptr };
2347+
}
2348+
}
2349+
2350+
return { true, stmt };
2351+
}
2352+
23352353
public:
23362354
bool foundAsync() { return FoundAsync; }
23372355
};

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -592,37 +592,17 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
592592
// check to see if the sequence expr is throwing (and async), if so require
593593
// the stmt to have a try loc
594594
if (stmt->getAwaitLoc().isValid()) {
595-
auto Ty = sequence->getType();
596-
if (Ty.isNull()) {
597-
auto DRE = dyn_cast<DeclRefExpr>(sequence);
598-
if (DRE) {
599-
Ty = DRE->getDecl()->getInterfaceType();
600-
}
601-
if (Ty.isNull()) {
602-
return failed();
603-
}
604-
}
605-
auto context = Ty->getNominalOrBoundGenericNominal();
606-
if (!context) {
607-
// if no nominal type can be determined then we must consider this to be
608-
// a potential throwing source and concequently this must have a valid try
609-
// location to account for that potential ambiguity.
610-
if (stmt->getTryLoc().isInvalid()) {
611-
auto &diags = dc->getASTContext().Diags;
612-
diags.diagnose(stmt->getAwaitLoc(), diag::throwing_call_unhandled);
613-
return failed();
614-
} else {
615-
return false;
616-
}
617-
618-
}
595+
// fetch the sequence out of the statement
596+
// else wise the value is potentially unresolved
597+
auto Ty = stmt->getSequence()->getType();
619598
auto module = dc->getParentModule();
620599
auto conformanceRef = module->lookupConformance(Ty, sequenceProto);
621600

622601
if (conformanceRef.classifyAsThrows() &&
623602
stmt->getTryLoc().isInvalid()) {
624603
auto &diags = dc->getASTContext().Diags;
625-
diags.diagnose(stmt->getAwaitLoc(), diag::throwing_call_unhandled);
604+
diags.diagnose(stmt->getAwaitLoc(), diag::throwing_call_unhandled)
605+
.fixItInsert(stmt->getAwaitLoc(), "try");
626606

627607
return failed();
628608
}

lib/Sema/TypeCheckEffects.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,13 @@ class ApplyClassifier {
528528
}
529529

530530
if (classifiedAsThrows) {
531+
// multiple passes can occur, so ensure that any sub-expressions of this
532+
// call are marked as throws to mimic the closure variant.
533+
if (auto subExpr = dyn_cast<ApplyExpr>(E->getFn())) {
534+
if (!subExpr->isThrowsSet()) {
535+
subExpr->setThrows(true);
536+
}
537+
}
531538
return Classification::forRethrowingOnly(
532539
PotentialThrowReason::forRethrowsConformance(E), isAsync);
533540
}
@@ -2096,6 +2103,12 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
20962103
Classification::forThrow(PotentialThrowReason::forThrow(),
20972104
/*async*/false));
20982105
}
2106+
if (S->getAwaitLoc().isValid() &&
2107+
!Flags.has(ContextFlags::HasAnyAsyncSite)) {
2108+
if (!CurContext.handlesAsync()) {
2109+
CurContext.diagnoseUnhandledAsyncSite(Ctx.Diags, S);
2110+
}
2111+
}
20992112
return ShouldRecurse;
21002113
}
21012114
};
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-concurrency
2+
// REQUIRES: concurrency
3+
4+
// expected-note@+1{{add 'async' to function 'missingAsync' to make it asynchronous}}
5+
func missingAsync<T : AsyncSequence>(_ seq: T) throws {
6+
for try await _ in seq { } // expected-error{{'async' in a function that does not support concurrency}}
7+
}
8+
9+
func missingThrows<T : AsyncSequence>(_ seq: T) async {
10+
for try await _ in seq { } // expected-error{{error is not handled because the enclosing function is not declared 'throws'}}
11+
}
12+
13+
func executeAsync(_ work: () async -> Void) { }
14+
func execute(_ work: () -> Void) { }
15+
16+
func missingThrowingInBlock<T : AsyncSequence>(_ seq: T) {
17+
executeAsync { // expected-error{{invalid conversion from throwing function of type '() async throws -> Void' to non-throwing function type '() async -> Void'}}
18+
for try await _ in seq { }
19+
}
20+
}
21+
22+
func missingTryInBlock<T : AsyncSequence>(_ seq: T) {
23+
executeAsync {
24+
for await _ in seq { } // expected-error{{call can throw, but the error is not handled}}
25+
}
26+
}
27+
28+
func missingAsyncInBlock<T : AsyncSequence>(_ seq: T) {
29+
execute { // expected-error{{invalid conversion from 'async' function of type '() async -> Void' to synchronous function type '() -> Void'}}
30+
do {
31+
for try await _ in seq { }
32+
} catch { }
33+
}
34+
}
35+
36+
func doubleDiagCheckGeneric<T : AsyncSequence>(_ seq: T) async {
37+
var it = seq.makeAsyncIterator()
38+
// expected-note@+2{{call is to 'rethrows' function, but a conformance has a throwing witness}}
39+
// expected-error@+1{{call can throw, but it is not marked with 'try' and the error is not handled}}
40+
let _ = await it.next()
41+
}
42+
43+
struct ThrowingAsyncSequence: AsyncSequence, AsyncIteratorProtocol {
44+
typealias Element = Int
45+
typealias AsyncIterator = Self
46+
mutating func next() async throws -> Int? {
47+
return nil
48+
}
49+
50+
func makeAsyncIterator() -> Self { return self }
51+
}
52+
53+
func doubleDiagCheckConcrete(_ seq: ThrowingAsyncSequence) async {
54+
var it = seq.makeAsyncIterator()
55+
// expected-error@+1{{call can throw, but it is not marked with 'try' and the error is not handled}}
56+
let _ = await it.next()
57+
}

0 commit comments

Comments
 (0)