@@ -8929,21 +8929,46 @@ static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
8929
8929
8930
8930
Expr *nextCall = rewrittenTarget->getAsExpr ();
8931
8931
// Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
8932
- // requirement is `async throws`
8932
+ // witness could be `async throws`.
8933
8933
if (isAsync) {
8934
- auto &ctx = cs.getASTContext ();
8935
- auto nextRefType =
8936
- solution
8937
- .getResolvedType (
8938
- cast<ApplyExpr>(cast<AwaitExpr>(nextCall)->getSubExpr ())
8939
- ->getFn ())
8940
- ->castTo <FunctionType>();
8941
-
8942
- // If the inferred witness is throwing, we need to wrap the call
8943
- // into `try` expression.
8944
- if (nextRefType->isThrowing ())
8945
- nextCall = TryExpr::createImplicit (ctx, /* tryLoc=*/ SourceLoc (),
8946
- nextCall, nextCall->getType ());
8934
+ // Cannot use `forEachChildExpr` here because we need to
8935
+ // to wrap a call in `try` and then stop immediately after.
8936
+ struct TryInjector : ASTWalker {
8937
+ ASTContext &C;
8938
+ const Solution &S;
8939
+
8940
+ bool ShouldStop = false ;
8941
+
8942
+ TryInjector (ASTContext &ctx, const Solution &solution)
8943
+ : C(ctx), S(solution) {}
8944
+
8945
+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
8946
+ if (ShouldStop)
8947
+ return Action::Stop ();
8948
+
8949
+ if (auto *call = dyn_cast<CallExpr>(E)) {
8950
+ // There is a single call expression in `nextCall`.
8951
+ ShouldStop = true ;
8952
+
8953
+ auto nextRefType =
8954
+ S.getResolvedType (call->getFn ())->castTo <FunctionType>();
8955
+
8956
+ // If the inferred witness is throwing, we need to wrap the call
8957
+ // into `try` expression.
8958
+ if (nextRefType->isThrowing ()) {
8959
+ auto *tryExpr = TryExpr::createImplicit (
8960
+ C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
8961
+ // Cannot stop here because we need to make sure that
8962
+ // the new expression gets injected into AST.
8963
+ return Action::SkipChildren (tryExpr);
8964
+ }
8965
+ }
8966
+
8967
+ return Action::Continue (E);
8968
+ }
8969
+ };
8970
+
8971
+ nextCall->walk (TryInjector (cs.getASTContext (), solution));
8947
8972
}
8948
8973
8949
8974
stmt->setNextCall (nextCall);
0 commit comments