@@ -9139,20 +9139,22 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9139
9139
auto *parsedSequence = stmt->getParsedSequence ();
9140
9140
bool isAsync = stmt->getAwaitLoc ().isValid ();
9141
9141
9142
- // Simplify the various types.
9143
- forEachStmtInfo.sequenceType =
9144
- solution.simplifyType (forEachStmtInfo.sequenceType );
9145
- forEachStmtInfo.elementType =
9146
- solution.simplifyType (forEachStmtInfo.elementType );
9147
- forEachStmtInfo.initType =
9148
- solution.simplifyType (forEachStmtInfo.initType );
9149
-
9150
9142
auto &cs = solution.getConstraintSystem ();
9151
9143
auto *dc = target.getDeclContext ();
9152
9144
9153
- // First, let's apply the solution to the sequence expression.
9154
- {
9155
- auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar ;
9145
+ if (forEachStmtInfo.isa <SequenceIterationInfo>()) {
9146
+ auto sequenceIterationInfo =
9147
+ *forEachStmtInfo.dyn_cast <SequenceIterationInfo>();
9148
+ // Simplify the various types.
9149
+ sequenceIterationInfo.sequenceType =
9150
+ solution.simplifyType (sequenceIterationInfo.sequenceType );
9151
+ sequenceIterationInfo.elementType =
9152
+ solution.simplifyType (sequenceIterationInfo.elementType );
9153
+ sequenceIterationInfo.initType =
9154
+ solution.simplifyType (sequenceIterationInfo.initType );
9155
+
9156
+ // First, let's apply the solution to the expression.
9157
+ auto *makeIteratorVar = sequenceIterationInfo.makeIteratorVar ;
9156
9158
9157
9159
auto makeIteratorTarget = *cs.getTargetFor ({makeIteratorVar, /* index=*/ 0 });
9158
9160
@@ -9167,127 +9169,126 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9167
9169
}
9168
9170
9169
9171
stmt->setIteratorVar (makeIteratorVar);
9170
- }
9171
9172
9172
- // Now, `$iterator.next()` call.
9173
- {
9174
- auto nextTarget = *cs.getTargetFor (forEachStmtInfo .nextCall );
9173
+ // Now, `$iterator.next()` call.
9174
+ {
9175
+ auto nextTarget = *cs.getTargetFor (sequenceIterationInfo .nextCall );
9175
9176
9176
- auto rewrittenTarget = rewriteTarget (nextTarget);
9177
- if (!rewrittenTarget)
9178
- return llvm::None;
9177
+ auto rewrittenTarget = rewriteTarget (nextTarget);
9178
+ if (!rewrittenTarget)
9179
+ return llvm::None;
9179
9180
9180
- Expr *nextCall = rewrittenTarget->getAsExpr ();
9181
- // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9182
- // witness could be `async throws`.
9183
- if (isAsync) {
9184
- // Cannot use `forEachChildExpr` here because we need to
9185
- // to wrap a call in `try` and then stop immediately after.
9186
- struct TryInjector : ASTWalker {
9187
- ASTContext &C;
9188
- const Solution &S;
9181
+ Expr *nextCall = rewrittenTarget->getAsExpr ();
9182
+ // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9183
+ // witness could be `async throws`.
9184
+ if (isAsync) {
9185
+ // Cannot use `forEachChildExpr` here because we need to
9186
+ // to wrap a call in `try` and then stop immediately after.
9187
+ struct TryInjector : ASTWalker {
9188
+ ASTContext &C;
9189
+ const Solution &S;
9189
9190
9190
- bool ShouldStop = false ;
9191
+ bool ShouldStop = false ;
9191
9192
9192
- TryInjector (ASTContext &ctx, const Solution &solution)
9193
- : C(ctx), S(solution) {}
9193
+ TryInjector (ASTContext &ctx, const Solution &solution)
9194
+ : C(ctx), S(solution) {}
9194
9195
9195
- MacroWalking getMacroWalkingBehavior () const override {
9196
- return MacroWalking::Expansion;
9197
- }
9196
+ MacroWalking getMacroWalkingBehavior () const override {
9197
+ return MacroWalking::Expansion;
9198
+ }
9198
9199
9199
- PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
9200
- if (ShouldStop)
9201
- return Action::Stop ();
9202
-
9203
- if (auto *call = dyn_cast<CallExpr>(E)) {
9204
- // There is a single call expression in `nextCall`.
9205
- ShouldStop = true ;
9206
-
9207
- auto nextRefType =
9208
- S.getResolvedType (call->getFn ())->castTo <FunctionType>();
9209
-
9210
- // If the inferred witness is throwing, we need to wrap the call
9211
- // into `try` expression.
9212
- if (nextRefType->isThrowing ()) {
9213
- auto *tryExpr = TryExpr::createImplicit (
9214
- C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
9215
- // Cannot stop here because we need to make sure that
9216
- // the new expression gets injected into AST.
9217
- return Action::SkipChildren (tryExpr);
9200
+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
9201
+ if (ShouldStop)
9202
+ return Action::Stop ();
9203
+
9204
+ if (auto *call = dyn_cast<CallExpr>(E)) {
9205
+ // There is a single call expression in `nextCall`.
9206
+ ShouldStop = true ;
9207
+
9208
+ auto nextRefType =
9209
+ S.getResolvedType (call->getFn ())->castTo <FunctionType>();
9210
+
9211
+ // If the inferred witness is throwing, we need to wrap the call
9212
+ // into `try` expression.
9213
+ if (nextRefType->isThrowing ()) {
9214
+ auto *tryExpr = TryExpr::createImplicit (
9215
+ C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
9216
+ // Cannot stop here because we need to make sure that
9217
+ // the new expression gets injected into AST.
9218
+ return Action::SkipChildren (tryExpr);
9219
+ }
9218
9220
}
9221
+
9222
+ return Action::Continue (E);
9219
9223
}
9224
+ };
9220
9225
9221
- return Action::Continue (E);
9222
- }
9223
- };
9226
+ nextCall->walk (TryInjector (cs.getASTContext (), solution));
9227
+ }
9224
9228
9225
- nextCall-> walk ( TryInjector (cs. getASTContext (), solution) );
9229
+ stmt-> setNextCall (nextCall );
9226
9230
}
9227
9231
9228
- stmt->setNextCall (nextCall);
9229
- }
9230
-
9231
- // Coerce the pattern to the element type.
9232
- {
9233
- TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9234
- options |= TypeResolutionFlags::OverrideType;
9232
+ // Coerce the pattern to the element type.
9233
+ {
9234
+ TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9235
+ options |= TypeResolutionFlags::OverrideType;
9235
9236
9236
- auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9237
- return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9238
- };
9237
+ auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9238
+ return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9239
+ };
9239
9240
9240
- // Apply the solution to the pattern as well.
9241
- auto contextualPattern = target.getContextualPattern ();
9242
- auto coercedPattern = TypeChecker::coercePatternToType (
9243
- contextualPattern, forEachStmtInfo .initType , options,
9244
- tryRewritePattern);
9245
- if (!coercedPattern)
9246
- return llvm::None;
9241
+ // Apply the solution to the pattern as well.
9242
+ auto contextualPattern = target.getContextualPattern ();
9243
+ auto coercedPattern = TypeChecker::coercePatternToType (
9244
+ contextualPattern, sequenceIterationInfo .initType , options,
9245
+ tryRewritePattern);
9246
+ if (!coercedPattern)
9247
+ return llvm::None;
9247
9248
9248
- stmt->setPattern (coercedPattern);
9249
- resultTarget.setPattern (coercedPattern);
9250
- }
9249
+ stmt->setPattern (coercedPattern);
9250
+ resultTarget.setPattern (coercedPattern);
9251
+ }
9251
9252
9252
- // Apply the solution to the filtering condition, if there is one.
9253
- if (auto *whereExpr = stmt->getWhere ()) {
9254
- auto whereTarget = *cs.getTargetFor (whereExpr);
9253
+ // Apply the solution to the filtering condition, if there is one.
9254
+ if (auto *whereExpr = stmt->getWhere ()) {
9255
+ auto whereTarget = *cs.getTargetFor (whereExpr);
9255
9256
9256
- auto rewrittenTarget = rewriteTarget (whereTarget);
9257
- if (!rewrittenTarget)
9258
- return llvm::None;
9257
+ auto rewrittenTarget = rewriteTarget (whereTarget);
9258
+ if (!rewrittenTarget)
9259
+ return llvm::None;
9259
9260
9260
- stmt->setWhere (rewrittenTarget->getAsExpr ());
9261
- }
9261
+ stmt->setWhere (rewrittenTarget->getAsExpr ());
9262
+ }
9262
9263
9263
- // Convert that llvm::Optional<Element> value to the type of the pattern.
9264
- auto optPatternType = OptionalType::get (forEachStmtInfo.initType );
9265
- Type nextResultType = OptionalType::get (forEachStmtInfo.elementType );
9266
- if (!optPatternType->isEqual (nextResultType)) {
9267
- ASTContext &ctx = cs.getASTContext ();
9268
- OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr (
9269
- stmt->getInLoc (), nextResultType->getOptionalObjectType (),
9270
- /* isPlaceholder=*/ true );
9271
- Expr *convertElementExpr = elementExpr;
9272
- if (TypeChecker::typeCheckExpression (
9273
- convertElementExpr, dc,
9274
- /* contextualInfo=*/ {forEachStmtInfo.initType , CTP_CoerceOperand})
9275
- .isNull ()) {
9276
- return llvm::None;
9264
+ // Convert that llvm::Optional<Element> value to the type of the pattern.
9265
+ auto optPatternType = OptionalType::get (sequenceIterationInfo.initType );
9266
+ Type nextResultType = OptionalType::get (sequenceIterationInfo.elementType );
9267
+ if (!optPatternType->isEqual (nextResultType)) {
9268
+ ASTContext &ctx = cs.getASTContext ();
9269
+ OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr (
9270
+ stmt->getInLoc (), nextResultType->getOptionalObjectType (),
9271
+ /* isPlaceholder=*/ true );
9272
+ Expr *convertElementExpr = elementExpr;
9273
+ if (TypeChecker::typeCheckExpression (
9274
+ convertElementExpr, dc,
9275
+ /* contextualInfo=*/
9276
+ {sequenceIterationInfo.initType , CTP_CoerceOperand})
9277
+ .isNull ()) {
9278
+ return llvm::None;
9279
+ }
9280
+ elementExpr->setIsPlaceholder (false );
9281
+ stmt->setElementExpr (elementExpr);
9282
+ stmt->setConvertElementExpr (convertElementExpr);
9277
9283
}
9278
- elementExpr->setIsPlaceholder (false );
9279
- stmt->setElementExpr (elementExpr);
9280
- stmt->setConvertElementExpr (convertElementExpr);
9281
- }
9282
9284
9283
- // Get the conformance of the sequence type to the Sequence protocol.
9284
- {
9285
+ // Get the conformance of the sequence type to the Sequence protocol.
9285
9286
auto sequenceProto = TypeChecker::getProtocol (
9286
9287
cs.getASTContext (), stmt->getForLoc (),
9287
9288
stmt->getAwaitLoc ().isValid () ? KnownProtocolKind::AsyncSequence
9288
9289
: KnownProtocolKind::Sequence);
9289
9290
9290
- auto type = forEachStmtInfo .sequenceType ->getRValueType ();
9291
+ auto type = sequenceIterationInfo .sequenceType ->getRValueType ();
9291
9292
if (type->isExistentialType ()) {
9292
9293
auto *contextualLoc = solution.getConstraintLocator (
9293
9294
parsedSequence, LocatorPathElt::ContextualType (CTP_ForEachSequence));
@@ -9300,6 +9301,48 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
9300
9301
stmt->setSequenceConformance (sequenceConformance);
9301
9302
}
9302
9303
9304
+ if (forEachStmtInfo.isa <PackIterationInfo>()) {
9305
+ auto packIterationInfo = *forEachStmtInfo.dyn_cast <PackIterationInfo>();
9306
+
9307
+ // First, let's apply the solution to the expression.
9308
+ auto makeSequenceTarget = *cs.getTargetFor (parsedSequence);
9309
+ auto rewrittenTarget = rewriteTarget (makeSequenceTarget);
9310
+ if (!rewrittenTarget)
9311
+ return llvm::None;
9312
+
9313
+ // Coerce the pattern to the element type.
9314
+ {
9315
+ TypeResolutionOptions options (TypeResolverContext::ForEachStmt);
9316
+ options |= TypeResolutionFlags::OverrideType;
9317
+
9318
+ auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9319
+ return ::tryRewriteExprPattern (EP, solution, ty, rewriteTarget);
9320
+ };
9321
+
9322
+ // Apply the solution to the pattern as well.
9323
+ auto contextualPattern = target.getContextualPattern ();
9324
+ auto coercedPattern = TypeChecker::coercePatternToType (
9325
+ contextualPattern, packIterationInfo.expansion ->getType (), options,
9326
+ tryRewritePattern);
9327
+ if (!coercedPattern)
9328
+ return llvm::None;
9329
+
9330
+ stmt->setPattern (coercedPattern);
9331
+ resultTarget.setPattern (coercedPattern);
9332
+ }
9333
+
9334
+ // Apply the solution to the filtering condition, if there is one.
9335
+ if (auto *whereExpr = stmt->getWhere ()) {
9336
+ auto whereTarget = *cs.getTargetFor (whereExpr);
9337
+
9338
+ auto rewrittenTarget = rewriteTarget (whereTarget);
9339
+ if (!rewrittenTarget)
9340
+ return llvm::None;
9341
+
9342
+ stmt->setWhere (rewrittenTarget->getAsExpr ());
9343
+ }
9344
+ }
9345
+
9303
9346
return resultTarget;
9304
9347
}
9305
9348
0 commit comments