Skip to content

Commit 42a4340

Browse files
committed
[Variadic Generics] Enable Pack Iteration
1 parent 9179bc5 commit 42a4340

File tree

6 files changed

+345
-224
lines changed

6 files changed

+345
-224
lines changed

include/swift/Basic/Features.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ EXPERIMENTAL_FEATURE(PlaygroundExtendedCallbacks, true)
227227
/// Enable the `@_rawLayout` attribute.
228228
EXPERIMENTAL_FEATURE(RawLayout, true)
229229

230+
/// Enable pack iteration.
231+
EXPERIMENTAL_FEATURE(PackIteration, true)
232+
230233
#undef EXPERIMENTAL_FEATURE_EXCLUDED_FROM_MODULE_INTERFACE
231234
#undef EXPERIMENTAL_FEATURE
232235
#undef UPCOMING_FEATURE

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
namespace swift {
2929

3030
namespace constraints {
31-
/// Describes information about a for-each loop that needs to be tracked
32-
/// within the constraint system.
33-
struct ForEachStmtInfo {
31+
/// Describes information specific to a sequence
32+
/// in a for-each loop.
33+
struct SequenceIterationInfo {
3434
/// The type of the sequence.
3535
Type sequenceType;
3636

@@ -47,6 +47,23 @@ struct ForEachStmtInfo {
4747
Expr *nextCall;
4848
};
4949

50+
/// Describes information specific to a pack expansion expression
51+
/// in a for-each loop.
52+
struct PackIterationInfo {
53+
/// The pack expansion expression.
54+
PackExpansionExpr *expansion;
55+
56+
/// The mapping from pack types in the outer context to element types inside
57+
/// the for-each loop.
58+
GenericEnvironment *openedElementEnvironment;
59+
};
60+
61+
/// Describes information about a for-each loop that needs to be tracked
62+
/// within the constraint system.
63+
struct ForEachStmtInfo : TaggedUnion<SequenceIterationInfo, PackIterationInfo> {
64+
using TaggedUnion::TaggedUnion;
65+
};
66+
5067
/// Describes the target to which a constraint system's solution can be
5168
/// applied.
5269
class SyntacticElementTarget {

lib/AST/ASTPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,6 +3512,10 @@ static bool usesFeatureNewCxxMethodSafetyHeuristics(Decl *decl) {
35123512
return decl->hasClangNode();
35133513
}
35143514

3515+
static bool usesFeaturePackIteration(Decl *decl) {
3516+
return false;
3517+
}
3518+
35153519
/// Suppress the printing of a particular feature.
35163520
static void suppressingFeature(PrintOptions &options, Feature feature,
35173521
llvm::function_ref<void()> action) {

lib/Sema/CSApply.cpp

Lines changed: 151 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -9127,20 +9127,22 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91279127
auto *parsedSequence = stmt->getParsedSequence();
91289128
bool isAsync = stmt->getAwaitLoc().isValid();
91299129

9130-
// Simplify the various types.
9131-
forEachStmtInfo.sequenceType =
9132-
solution.simplifyType(forEachStmtInfo.sequenceType);
9133-
forEachStmtInfo.elementType =
9134-
solution.simplifyType(forEachStmtInfo.elementType);
9135-
forEachStmtInfo.initType =
9136-
solution.simplifyType(forEachStmtInfo.initType);
9137-
91389130
auto &cs = solution.getConstraintSystem();
91399131
auto *dc = target.getDeclContext();
91409132

9141-
// First, let's apply the solution to the sequence expression.
9142-
{
9143-
auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar;
9133+
if (forEachStmtInfo.isa<SequenceIterationInfo>()) {
9134+
auto sequenceIterationInfo =
9135+
*forEachStmtInfo.dyn_cast<SequenceIterationInfo>();
9136+
// Simplify the various types.
9137+
sequenceIterationInfo.sequenceType =
9138+
solution.simplifyType(sequenceIterationInfo.sequenceType);
9139+
sequenceIterationInfo.elementType =
9140+
solution.simplifyType(sequenceIterationInfo.elementType);
9141+
sequenceIterationInfo.initType =
9142+
solution.simplifyType(sequenceIterationInfo.initType);
9143+
9144+
// First, let's apply the solution to the expression.
9145+
auto *makeIteratorVar = sequenceIterationInfo.makeIteratorVar;
91449146

91459147
auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0});
91469148

@@ -9155,127 +9157,126 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91559157
}
91569158

91579159
stmt->setIteratorVar(makeIteratorVar);
9158-
}
91599160

9160-
// Now, `$iterator.next()` call.
9161-
{
9162-
auto nextTarget = *cs.getTargetFor(forEachStmtInfo.nextCall);
9161+
// Now, `$iterator.next()` call.
9162+
{
9163+
auto nextTarget = *cs.getTargetFor(sequenceIterationInfo.nextCall);
91639164

9164-
auto rewrittenTarget = rewriteTarget(nextTarget);
9165-
if (!rewrittenTarget)
9166-
return llvm::None;
9165+
auto rewrittenTarget = rewriteTarget(nextTarget);
9166+
if (!rewrittenTarget)
9167+
return llvm::None;
91679168

9168-
Expr *nextCall = rewrittenTarget->getAsExpr();
9169-
// Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9170-
// witness could be `async throws`.
9171-
if (isAsync) {
9172-
// Cannot use `forEachChildExpr` here because we need to
9173-
// to wrap a call in `try` and then stop immediately after.
9174-
struct TryInjector : ASTWalker {
9175-
ASTContext &C;
9176-
const Solution &S;
9169+
Expr *nextCall = rewrittenTarget->getAsExpr();
9170+
// Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9171+
// witness could be `async throws`.
9172+
if (isAsync) {
9173+
// Cannot use `forEachChildExpr` here because we need to
9174+
// to wrap a call in `try` and then stop immediately after.
9175+
struct TryInjector : ASTWalker {
9176+
ASTContext &C;
9177+
const Solution &S;
91779178

9178-
bool ShouldStop = false;
9179+
bool ShouldStop = false;
91799180

9180-
TryInjector(ASTContext &ctx, const Solution &solution)
9181-
: C(ctx), S(solution) {}
9181+
TryInjector(ASTContext &ctx, const Solution &solution)
9182+
: C(ctx), S(solution) {}
91829183

9183-
MacroWalking getMacroWalkingBehavior() const override {
9184-
return MacroWalking::Expansion;
9185-
}
9184+
MacroWalking getMacroWalkingBehavior() const override {
9185+
return MacroWalking::Expansion;
9186+
}
91869187

9187-
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
9188-
if (ShouldStop)
9189-
return Action::Stop();
9190-
9191-
if (auto *call = dyn_cast<CallExpr>(E)) {
9192-
// There is a single call expression in `nextCall`.
9193-
ShouldStop = true;
9194-
9195-
auto nextRefType =
9196-
S.getResolvedType(call->getFn())->castTo<FunctionType>();
9197-
9198-
// If the inferred witness is throwing, we need to wrap the call
9199-
// into `try` expression.
9200-
if (nextRefType->isThrowing()) {
9201-
auto *tryExpr = TryExpr::createImplicit(
9202-
C, /*tryLoc=*/call->getStartLoc(), call, call->getType());
9203-
// Cannot stop here because we need to make sure that
9204-
// the new expression gets injected into AST.
9205-
return Action::SkipChildren(tryExpr);
9188+
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
9189+
if (ShouldStop)
9190+
return Action::Stop();
9191+
9192+
if (auto *call = dyn_cast<CallExpr>(E)) {
9193+
// There is a single call expression in `nextCall`.
9194+
ShouldStop = true;
9195+
9196+
auto nextRefType =
9197+
S.getResolvedType(call->getFn())->castTo<FunctionType>();
9198+
9199+
// If the inferred witness is throwing, we need to wrap the call
9200+
// into `try` expression.
9201+
if (nextRefType->isThrowing()) {
9202+
auto *tryExpr = TryExpr::createImplicit(
9203+
C, /*tryLoc=*/call->getStartLoc(), call, call->getType());
9204+
// Cannot stop here because we need to make sure that
9205+
// the new expression gets injected into AST.
9206+
return Action::SkipChildren(tryExpr);
9207+
}
92069208
}
9209+
9210+
return Action::Continue(E);
92079211
}
9212+
};
92089213

9209-
return Action::Continue(E);
9210-
}
9211-
};
9214+
nextCall->walk(TryInjector(cs.getASTContext(), solution));
9215+
}
92129216

9213-
nextCall->walk(TryInjector(cs.getASTContext(), solution));
9217+
stmt->setNextCall(nextCall);
92149218
}
92159219

9216-
stmt->setNextCall(nextCall);
9217-
}
9218-
9219-
// Coerce the pattern to the element type.
9220-
{
9221-
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9222-
options |= TypeResolutionFlags::OverrideType;
9220+
// Coerce the pattern to the element type.
9221+
{
9222+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9223+
options |= TypeResolutionFlags::OverrideType;
92239224

9224-
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9225-
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9226-
};
9225+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9226+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9227+
};
92279228

9228-
// Apply the solution to the pattern as well.
9229-
auto contextualPattern = target.getContextualPattern();
9230-
auto coercedPattern = TypeChecker::coercePatternToType(
9231-
contextualPattern, forEachStmtInfo.initType, options,
9232-
tryRewritePattern);
9233-
if (!coercedPattern)
9234-
return llvm::None;
9229+
// Apply the solution to the pattern as well.
9230+
auto contextualPattern = target.getContextualPattern();
9231+
auto coercedPattern = TypeChecker::coercePatternToType(
9232+
contextualPattern, sequenceIterationInfo.initType, options,
9233+
tryRewritePattern);
9234+
if (!coercedPattern)
9235+
return llvm::None;
92359236

9236-
stmt->setPattern(coercedPattern);
9237-
resultTarget.setPattern(coercedPattern);
9238-
}
9237+
stmt->setPattern(coercedPattern);
9238+
resultTarget.setPattern(coercedPattern);
9239+
}
92399240

9240-
// Apply the solution to the filtering condition, if there is one.
9241-
if (auto *whereExpr = stmt->getWhere()) {
9242-
auto whereTarget = *cs.getTargetFor(whereExpr);
9241+
// Apply the solution to the filtering condition, if there is one.
9242+
if (auto *whereExpr = stmt->getWhere()) {
9243+
auto whereTarget = *cs.getTargetFor(whereExpr);
92439244

9244-
auto rewrittenTarget = rewriteTarget(whereTarget);
9245-
if (!rewrittenTarget)
9246-
return llvm::None;
9245+
auto rewrittenTarget = rewriteTarget(whereTarget);
9246+
if (!rewrittenTarget)
9247+
return llvm::None;
92479248

9248-
stmt->setWhere(rewrittenTarget->getAsExpr());
9249-
}
9249+
stmt->setWhere(rewrittenTarget->getAsExpr());
9250+
}
92509251

9251-
// Convert that llvm::Optional<Element> value to the type of the pattern.
9252-
auto optPatternType = OptionalType::get(forEachStmtInfo.initType);
9253-
Type nextResultType = OptionalType::get(forEachStmtInfo.elementType);
9254-
if (!optPatternType->isEqual(nextResultType)) {
9255-
ASTContext &ctx = cs.getASTContext();
9256-
OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr(
9257-
stmt->getInLoc(), nextResultType->getOptionalObjectType(),
9258-
/*isPlaceholder=*/true);
9259-
Expr *convertElementExpr = elementExpr;
9260-
if (TypeChecker::typeCheckExpression(
9261-
convertElementExpr, dc,
9262-
/*contextualInfo=*/{forEachStmtInfo.initType, CTP_CoerceOperand})
9263-
.isNull()) {
9264-
return llvm::None;
9252+
// Convert that llvm::Optional<Element> value to the type of the pattern.
9253+
auto optPatternType = OptionalType::get(sequenceIterationInfo.initType);
9254+
Type nextResultType = OptionalType::get(sequenceIterationInfo.elementType);
9255+
if (!optPatternType->isEqual(nextResultType)) {
9256+
ASTContext &ctx = cs.getASTContext();
9257+
OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr(
9258+
stmt->getInLoc(), nextResultType->getOptionalObjectType(),
9259+
/*isPlaceholder=*/true);
9260+
Expr *convertElementExpr = elementExpr;
9261+
if (TypeChecker::typeCheckExpression(
9262+
convertElementExpr, dc,
9263+
/*contextualInfo=*/
9264+
{sequenceIterationInfo.initType, CTP_CoerceOperand})
9265+
.isNull()) {
9266+
return llvm::None;
9267+
}
9268+
elementExpr->setIsPlaceholder(false);
9269+
stmt->setElementExpr(elementExpr);
9270+
stmt->setConvertElementExpr(convertElementExpr);
92659271
}
9266-
elementExpr->setIsPlaceholder(false);
9267-
stmt->setElementExpr(elementExpr);
9268-
stmt->setConvertElementExpr(convertElementExpr);
9269-
}
92709272

9271-
// Get the conformance of the sequence type to the Sequence protocol.
9272-
{
9273+
// Get the conformance of the sequence type to the Sequence protocol.
92739274
auto sequenceProto = TypeChecker::getProtocol(
92749275
cs.getASTContext(), stmt->getForLoc(),
92759276
stmt->getAwaitLoc().isValid() ? KnownProtocolKind::AsyncSequence
92769277
: KnownProtocolKind::Sequence);
92779278

9278-
auto type = forEachStmtInfo.sequenceType->getRValueType();
9279+
auto type = sequenceIterationInfo.sequenceType->getRValueType();
92799280
if (type->isExistentialType()) {
92809281
auto *contextualLoc = solution.getConstraintLocator(
92819282
parsedSequence, LocatorPathElt::ContextualType(CTP_ForEachSequence));
@@ -9287,6 +9288,51 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
92879288
"Couldn't find sequence conformance");
92889289
stmt->setSequenceConformance(sequenceConformance);
92899290
}
9291+
9292+
auto &ctx = cs.getASTContext();
9293+
if (ctx.LangOpts.hasFeature(Feature::PackIteration)) {
9294+
if (forEachStmtInfo.isa<PackIterationInfo>()) {
9295+
auto packIterationInfo = *forEachStmtInfo.dyn_cast<PackIterationInfo>();
9296+
9297+
// First, let's apply the solution to the expression.
9298+
auto makeSequenceTarget = *cs.getTargetFor(parsedSequence);
9299+
auto rewrittenTarget = rewriteTarget(makeSequenceTarget);
9300+
if (!rewrittenTarget)
9301+
return llvm::None;
9302+
9303+
// Coerce the pattern to the element type.
9304+
{
9305+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9306+
options |= TypeResolutionFlags::OverrideType;
9307+
9308+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9309+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9310+
};
9311+
9312+
// Apply the solution to the pattern as well.
9313+
auto contextualPattern = target.getContextualPattern();
9314+
auto coercedPattern = TypeChecker::coercePatternToType(
9315+
contextualPattern, packIterationInfo.expansion->getType(), options,
9316+
tryRewritePattern);
9317+
if (!coercedPattern)
9318+
return llvm::None;
9319+
9320+
stmt->setPattern(coercedPattern);
9321+
resultTarget.setPattern(coercedPattern);
9322+
}
9323+
9324+
// Apply the solution to the filtering condition, if there is one.
9325+
if (auto *whereExpr = stmt->getWhere()) {
9326+
auto whereTarget = *cs.getTargetFor(whereExpr);
9327+
9328+
auto rewrittenTarget = rewriteTarget(whereTarget);
9329+
if (!rewrittenTarget)
9330+
return llvm::None;
9331+
9332+
stmt->setWhere(rewrittenTarget->getAsExpr());
9333+
}
9334+
}
9335+
}
92909336

92919337
return resultTarget;
92929338
}

0 commit comments

Comments
 (0)