@@ -9127,20 +9127,22 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9127
9127
auto *parsedSequence = stmt->getParsedSequence ();
9128
9128
bool isAsync = stmt->getAwaitLoc ().isValid ();
9129
9129
9130
- // Simplify the various types.
9131
- forEachStmtInfo.sequenceType =
9132
- solution.simplifyType (forEachStmtInfo.sequenceType );
9133
- forEachStmtInfo.elementType =
9134
- solution.simplifyType (forEachStmtInfo.elementType );
9135
- forEachStmtInfo.initType =
9136
- solution.simplifyType (forEachStmtInfo.initType );
9137
-
9138
9130
auto &cs = solution.getConstraintSystem ();
9139
9131
auto *dc = target.getDeclContext ();
9140
9132
9141
- // First, let's apply the solution to the sequence expression.
9142
- {
9143
- auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar ;
9133
+ if (forEachStmtInfo.isa <SequenceIterationInfo>()) {
9134
+ auto sequenceIterationInfo =
9135
+ *forEachStmtInfo.dyn_cast <SequenceIterationInfo>();
9136
+ // Simplify the various types.
9137
+ sequenceIterationInfo.sequenceType =
9138
+ solution.simplifyType (sequenceIterationInfo.sequenceType );
9139
+ sequenceIterationInfo.elementType =
9140
+ solution.simplifyType (sequenceIterationInfo.elementType );
9141
+ sequenceIterationInfo.initType =
9142
+ solution.simplifyType (sequenceIterationInfo.initType );
9143
+
9144
+ // First, let's apply the solution to the expression.
9145
+ auto *makeIteratorVar = sequenceIterationInfo.makeIteratorVar ;
9144
9146
9145
9147
auto makeIteratorTarget = *cs.getTargetFor ({makeIteratorVar, /* index=*/ 0 });
9146
9148
@@ -9155,127 +9157,126 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9155
9157
}
9156
9158
9157
9159
stmt->setIteratorVar (makeIteratorVar);
9158
- }
9159
9160
9160
- // Now, `$iterator.next()` call.
9161
- {
9162
- auto nextTarget = *cs.getTargetFor (forEachStmtInfo .nextCall );
9161
+ // Now, `$iterator.next()` call.
9162
+ {
9163
+ auto nextTarget = *cs.getTargetFor (sequenceIterationInfo .nextCall );
9163
9164
9164
- auto rewrittenTarget = rewriteTarget (nextTarget);
9165
- if (!rewrittenTarget)
9166
- return llvm::None;
9165
+ auto rewrittenTarget = rewriteTarget (nextTarget);
9166
+ if (!rewrittenTarget)
9167
+ return llvm::None;
9167
9168
9168
- Expr *nextCall = rewrittenTarget->getAsExpr ();
9169
- // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9170
- // witness could be `async throws`.
9171
- if (isAsync) {
9172
- // Cannot use `forEachChildExpr` here because we need to
9173
- // to wrap a call in `try` and then stop immediately after.
9174
- struct TryInjector : ASTWalker {
9175
- ASTContext &C;
9176
- const Solution &S;
9169
+ Expr *nextCall = rewrittenTarget->getAsExpr ();
9170
+ // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9171
+ // witness could be `async throws`.
9172
+ if (isAsync) {
9173
+ // Cannot use `forEachChildExpr` here because we need to
9174
+ // to wrap a call in `try` and then stop immediately after.
9175
+ struct TryInjector : ASTWalker {
9176
+ ASTContext &C;
9177
+ const Solution &S;
9177
9178
9178
- bool ShouldStop = false ;
9179
+ bool ShouldStop = false ;
9179
9180
9180
- TryInjector (ASTContext &ctx, const Solution &solution)
9181
- : C(ctx), S(solution) {}
9181
+ TryInjector (ASTContext &ctx, const Solution &solution)
9182
+ : C(ctx), S(solution) {}
9182
9183
9183
- MacroWalking getMacroWalkingBehavior () const override {
9184
- return MacroWalking::Expansion;
9185
- }
9184
+ MacroWalking getMacroWalkingBehavior () const override {
9185
+ return MacroWalking::Expansion;
9186
+ }
9186
9187
9187
- PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
9188
- if (ShouldStop)
9189
- return Action::Stop ();
9190
-
9191
- if (auto *call = dyn_cast<CallExpr>(E)) {
9192
- // There is a single call expression in `nextCall`.
9193
- ShouldStop = true ;
9194
-
9195
- auto nextRefType =
9196
- S.getResolvedType (call->getFn ())->castTo <FunctionType>();
9197
-
9198
- // If the inferred witness is throwing, we need to wrap the call
9199
- // into `try` expression.
9200
- if (nextRefType->isThrowing ()) {
9201
- auto *tryExpr = TryExpr::createImplicit (
9202
- C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
9203
- // Cannot stop here because we need to make sure that
9204
- // the new expression gets injected into AST.
9205
- return Action::SkipChildren (tryExpr);
9188
+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
9189
+ if (ShouldStop)
9190
+ return Action::Stop ();
9191
+
9192
+ if (auto *call = dyn_cast<CallExpr>(E)) {
9193
+ // There is a single call expression in `nextCall`.
9194
+ ShouldStop = true ;
9195
+
9196
+ auto nextRefType =
9197
+ S.getResolvedType (call->getFn ())->castTo <FunctionType>();
9198
+
9199
+ // If the inferred witness is throwing, we need to wrap the call
9200
+ // into `try` expression.
9201
+ if (nextRefType->isThrowing ()) {
9202
+ auto *tryExpr = TryExpr::createImplicit (
9203
+ C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
9204
+ // Cannot stop here because we need to make sure that
9205
+ // the new expression gets injected into AST.
9206
+ return Action::SkipChildren (tryExpr);
9207
+ }
9206
9208
}
9209
+
9210
+ return Action::Continue (E);
9207
9211
}
9212
+ };
9208
9213
9209
- return Action::Continue (E);
9210
- }
9211
- };
9214
+ nextCall->walk (TryInjector (cs.getASTContext (), solution));
9215
+ }
9212
9216
9213
- nextCall-> walk ( TryInjector (cs. getASTContext (), solution) );
9217
+ stmt-> setNextCall (nextCall );
9214
9218
}
9215
9219
9216
- stmt->setNextCall (nextCall);
9217
- }
9218
-
9219
- // Coerce the pattern to the element type.
9220
- {
9221
- TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9222
- options |= TypeResolutionFlags::OverrideType;
9220
+ // Coerce the pattern to the element type.
9221
+ {
9222
+ TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9223
+ options |= TypeResolutionFlags::OverrideType;
9223
9224
9224
- auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9225
- return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9226
- };
9225
+ auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9226
+ return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9227
+ };
9227
9228
9228
- // Apply the solution to the pattern as well.
9229
- auto contextualPattern = target.getContextualPattern ();
9230
- auto coercedPattern = TypeChecker::coercePatternToType (
9231
- contextualPattern, forEachStmtInfo .initType , options,
9232
- tryRewritePattern);
9233
- if (!coercedPattern)
9234
- return llvm::None;
9229
+ // Apply the solution to the pattern as well.
9230
+ auto contextualPattern = target.getContextualPattern ();
9231
+ auto coercedPattern = TypeChecker::coercePatternToType (
9232
+ contextualPattern, sequenceIterationInfo .initType , options,
9233
+ tryRewritePattern);
9234
+ if (!coercedPattern)
9235
+ return llvm::None;
9235
9236
9236
- stmt->setPattern (coercedPattern);
9237
- resultTarget.setPattern (coercedPattern);
9238
- }
9237
+ stmt->setPattern (coercedPattern);
9238
+ resultTarget.setPattern (coercedPattern);
9239
+ }
9239
9240
9240
- // Apply the solution to the filtering condition, if there is one.
9241
- if (auto *whereExpr = stmt->getWhere ()) {
9242
- auto whereTarget = *cs.getTargetFor (whereExpr);
9241
+ // Apply the solution to the filtering condition, if there is one.
9242
+ if (auto *whereExpr = stmt->getWhere ()) {
9243
+ auto whereTarget = *cs.getTargetFor (whereExpr);
9243
9244
9244
- auto rewrittenTarget = rewriteTarget (whereTarget);
9245
- if (!rewrittenTarget)
9246
- return llvm::None;
9245
+ auto rewrittenTarget = rewriteTarget (whereTarget);
9246
+ if (!rewrittenTarget)
9247
+ return llvm::None;
9247
9248
9248
- stmt->setWhere (rewrittenTarget->getAsExpr ());
9249
- }
9249
+ stmt->setWhere (rewrittenTarget->getAsExpr ());
9250
+ }
9250
9251
9251
- // Convert that llvm::Optional<Element> value to the type of the pattern.
9252
- auto optPatternType = OptionalType::get (forEachStmtInfo.initType );
9253
- Type nextResultType = OptionalType::get (forEachStmtInfo.elementType );
9254
- if (!optPatternType->isEqual (nextResultType)) {
9255
- ASTContext &ctx = cs.getASTContext ();
9256
- OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr (
9257
- stmt->getInLoc (), nextResultType->getOptionalObjectType (),
9258
- /* isPlaceholder=*/ true );
9259
- Expr *convertElementExpr = elementExpr;
9260
- if (TypeChecker::typeCheckExpression (
9261
- convertElementExpr, dc,
9262
- /* contextualInfo=*/ {forEachStmtInfo.initType , CTP_CoerceOperand})
9263
- .isNull ()) {
9264
- return llvm::None;
9252
+ // Convert that llvm::Optional<Element> value to the type of the pattern.
9253
+ auto optPatternType = OptionalType::get (sequenceIterationInfo.initType );
9254
+ Type nextResultType = OptionalType::get (sequenceIterationInfo.elementType );
9255
+ if (!optPatternType->isEqual (nextResultType)) {
9256
+ ASTContext &ctx = cs.getASTContext ();
9257
+ OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr (
9258
+ stmt->getInLoc (), nextResultType->getOptionalObjectType (),
9259
+ /* isPlaceholder=*/ true );
9260
+ Expr *convertElementExpr = elementExpr;
9261
+ if (TypeChecker::typeCheckExpression (
9262
+ convertElementExpr, dc,
9263
+ /* contextualInfo=*/
9264
+ {sequenceIterationInfo.initType , CTP_CoerceOperand})
9265
+ .isNull ()) {
9266
+ return llvm::None;
9267
+ }
9268
+ elementExpr->setIsPlaceholder (false );
9269
+ stmt->setElementExpr (elementExpr);
9270
+ stmt->setConvertElementExpr (convertElementExpr);
9265
9271
}
9266
- elementExpr->setIsPlaceholder (false );
9267
- stmt->setElementExpr (elementExpr);
9268
- stmt->setConvertElementExpr (convertElementExpr);
9269
- }
9270
9272
9271
- // Get the conformance of the sequence type to the Sequence protocol.
9272
- {
9273
+ // Get the conformance of the sequence type to the Sequence protocol.
9273
9274
auto sequenceProto = TypeChecker::getProtocol (
9274
9275
cs.getASTContext (), stmt->getForLoc (),
9275
9276
stmt->getAwaitLoc ().isValid () ? KnownProtocolKind::AsyncSequence
9276
9277
: KnownProtocolKind::Sequence);
9277
9278
9278
- auto type = forEachStmtInfo .sequenceType ->getRValueType ();
9279
+ auto type = sequenceIterationInfo .sequenceType ->getRValueType ();
9279
9280
if (type->isExistentialType ()) {
9280
9281
auto *contextualLoc = solution.getConstraintLocator (
9281
9282
parsedSequence, LocatorPathElt::ContextualType (CTP_ForEachSequence));
@@ -9287,6 +9288,51 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9287
9288
" Couldn't find sequence conformance" );
9288
9289
stmt->setSequenceConformance (sequenceConformance);
9289
9290
}
9291
+
9292
+ auto &ctx = cs.getASTContext ();
9293
+ if (ctx.LangOpts .hasFeature (Feature::PackIteration)) {
9294
+ if (forEachStmtInfo.isa <PackIterationInfo>()) {
9295
+ auto packIterationInfo = *forEachStmtInfo.dyn_cast <PackIterationInfo>();
9296
+
9297
+ // First, let's apply the solution to the expression.
9298
+ auto makeSequenceTarget = *cs.getTargetFor (parsedSequence);
9299
+ auto rewrittenTarget = rewriteTarget (makeSequenceTarget);
9300
+ if (!rewrittenTarget)
9301
+ return llvm::None;
9302
+
9303
+ // Coerce the pattern to the element type.
9304
+ {
9305
+ TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9306
+ options |= TypeResolutionFlags::OverrideType;
9307
+
9308
+ auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9309
+ return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9310
+ };
9311
+
9312
+ // Apply the solution to the pattern as well.
9313
+ auto contextualPattern = target.getContextualPattern ();
9314
+ auto coercedPattern = TypeChecker::coercePatternToType (
9315
+ contextualPattern, packIterationInfo.expansion ->getType (), options,
9316
+ tryRewritePattern);
9317
+ if (!coercedPattern)
9318
+ return llvm::None;
9319
+
9320
+ stmt->setPattern (coercedPattern);
9321
+ resultTarget.setPattern (coercedPattern);
9322
+ }
9323
+
9324
+ // Apply the solution to the filtering condition, if there is one.
9325
+ if (auto *whereExpr = stmt->getWhere ()) {
9326
+ auto whereTarget = *cs.getTargetFor (whereExpr);
9327
+
9328
+ auto rewrittenTarget = rewriteTarget (whereTarget);
9329
+ if (!rewrittenTarget)
9330
+ return llvm::None;
9331
+
9332
+ stmt->setWhere (rewrittenTarget->getAsExpr ());
9333
+ }
9334
+ }
9335
+ }
9290
9336
9291
9337
return resultTarget;
9292
9338
}
0 commit comments