Skip to content

Commit 4b1b8c5

Browse files
committed
[Variadic Generics][WIP] Enable Pack Iteration
1 parent 7cfc8d6 commit 4b1b8c5

File tree

4 files changed

+335
-225
lines changed

4 files changed

+335
-225
lines changed

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
namespace swift {
2828

2929
namespace constraints {
30-
/// Describes information about a for-each loop that needs to be tracked
31-
/// within the constraint system.
32-
struct ForEachStmtInfo {
30+
/// Describes information specific to a sequence
31+
/// in a for-each loop.
32+
struct SequenceIterationInfo {
3333
/// The type of the sequence.
3434
Type sequenceType;
3535

@@ -46,6 +46,23 @@ struct ForEachStmtInfo {
4646
Expr *nextCall;
4747
};
4848

49+
/// Describes information specific to a pack expansion expression
50+
/// in a for-each loop.
51+
struct PackIterationInfo {
52+
/// The pack expansion expression.
53+
PackExpansionExpr *expansion;
54+
55+
/// The mapping from pack types in the outer context to element types inside
56+
/// the for-each loop.
57+
GenericEnvironment *openedElementEnvironment;
58+
};
59+
60+
/// Describes information about a for-each loop that needs to be tracked
61+
/// within the constraint system.
62+
struct ForEachStmtInfo : TaggedUnion<SequenceIterationInfo, PackIterationInfo> {
63+
using TaggedUnion::TaggedUnion;
64+
};
65+
4966
/// Describes the target to which a constraint system's solution can be
5067
/// applied.
5168
class SyntacticElementTarget {

lib/Sema/CSApply.cpp

Lines changed: 148 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -9139,20 +9139,22 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91399139
auto *parsedSequence = stmt->getParsedSequence();
91409140
bool isAsync = stmt->getAwaitLoc().isValid();
91419141

9142-
// Simplify the various types.
9143-
forEachStmtInfo.sequenceType =
9144-
solution.simplifyType(forEachStmtInfo.sequenceType);
9145-
forEachStmtInfo.elementType =
9146-
solution.simplifyType(forEachStmtInfo.elementType);
9147-
forEachStmtInfo.initType =
9148-
solution.simplifyType(forEachStmtInfo.initType);
9149-
91509142
auto &cs = solution.getConstraintSystem();
91519143
auto *dc = target.getDeclContext();
91529144

9153-
// First, let's apply the solution to the sequence expression.
9154-
{
9155-
auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar;
9145+
if (forEachStmtInfo.isa<SequenceIterationInfo>()) {
9146+
auto sequenceIterationInfo =
9147+
*forEachStmtInfo.dyn_cast<SequenceIterationInfo>();
9148+
// Simplify the various types.
9149+
sequenceIterationInfo.sequenceType =
9150+
solution.simplifyType(sequenceIterationInfo.sequenceType);
9151+
sequenceIterationInfo.elementType =
9152+
solution.simplifyType(sequenceIterationInfo.elementType);
9153+
sequenceIterationInfo.initType =
9154+
solution.simplifyType(sequenceIterationInfo.initType);
9155+
9156+
// First, let's apply the solution to the expression.
9157+
auto *makeIteratorVar = sequenceIterationInfo.makeIteratorVar;
91569158

91579159
auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0});
91589160

@@ -9167,127 +9169,126 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91679169
}
91689170

91699171
stmt->setIteratorVar(makeIteratorVar);
9170-
}
91719172

9172-
// Now, `$iterator.next()` call.
9173-
{
9174-
auto nextTarget = *cs.getTargetFor(forEachStmtInfo.nextCall);
9173+
// Now, `$iterator.next()` call.
9174+
{
9175+
auto nextTarget = *cs.getTargetFor(sequenceIterationInfo.nextCall);
91759176

9176-
auto rewrittenTarget = rewriteTarget(nextTarget);
9177-
if (!rewrittenTarget)
9178-
return llvm::None;
9177+
auto rewrittenTarget = rewriteTarget(nextTarget);
9178+
if (!rewrittenTarget)
9179+
return llvm::None;
91799180

9180-
Expr *nextCall = rewrittenTarget->getAsExpr();
9181-
// Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9182-
// witness could be `async throws`.
9183-
if (isAsync) {
9184-
// Cannot use `forEachChildExpr` here because we need to
9185-
// to wrap a call in `try` and then stop immediately after.
9186-
struct TryInjector : ASTWalker {
9187-
ASTContext &C;
9188-
const Solution &S;
9181+
Expr *nextCall = rewrittenTarget->getAsExpr();
9182+
// Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
9183+
// witness could be `async throws`.
9184+
if (isAsync) {
9185+
// Cannot use `forEachChildExpr` here because we need to
9186+
// to wrap a call in `try` and then stop immediately after.
9187+
struct TryInjector : ASTWalker {
9188+
ASTContext &C;
9189+
const Solution &S;
91899190

9190-
bool ShouldStop = false;
9191+
bool ShouldStop = false;
91919192

9192-
TryInjector(ASTContext &ctx, const Solution &solution)
9193-
: C(ctx), S(solution) {}
9193+
TryInjector(ASTContext &ctx, const Solution &solution)
9194+
: C(ctx), S(solution) {}
91949195

9195-
MacroWalking getMacroWalkingBehavior() const override {
9196-
return MacroWalking::Expansion;
9197-
}
9196+
MacroWalking getMacroWalkingBehavior() const override {
9197+
return MacroWalking::Expansion;
9198+
}
91989199

9199-
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
9200-
if (ShouldStop)
9201-
return Action::Stop();
9202-
9203-
if (auto *call = dyn_cast<CallExpr>(E)) {
9204-
// There is a single call expression in `nextCall`.
9205-
ShouldStop = true;
9206-
9207-
auto nextRefType =
9208-
S.getResolvedType(call->getFn())->castTo<FunctionType>();
9209-
9210-
// If the inferred witness is throwing, we need to wrap the call
9211-
// into `try` expression.
9212-
if (nextRefType->isThrowing()) {
9213-
auto *tryExpr = TryExpr::createImplicit(
9214-
C, /*tryLoc=*/call->getStartLoc(), call, call->getType());
9215-
// Cannot stop here because we need to make sure that
9216-
// the new expression gets injected into AST.
9217-
return Action::SkipChildren(tryExpr);
9200+
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
9201+
if (ShouldStop)
9202+
return Action::Stop();
9203+
9204+
if (auto *call = dyn_cast<CallExpr>(E)) {
9205+
// There is a single call expression in `nextCall`.
9206+
ShouldStop = true;
9207+
9208+
auto nextRefType =
9209+
S.getResolvedType(call->getFn())->castTo<FunctionType>();
9210+
9211+
// If the inferred witness is throwing, we need to wrap the call
9212+
// into `try` expression.
9213+
if (nextRefType->isThrowing()) {
9214+
auto *tryExpr = TryExpr::createImplicit(
9215+
C, /*tryLoc=*/call->getStartLoc(), call, call->getType());
9216+
// Cannot stop here because we need to make sure that
9217+
// the new expression gets injected into AST.
9218+
return Action::SkipChildren(tryExpr);
9219+
}
92189220
}
9221+
9222+
return Action::Continue(E);
92199223
}
9224+
};
92209225

9221-
return Action::Continue(E);
9222-
}
9223-
};
9226+
nextCall->walk(TryInjector(cs.getASTContext(), solution));
9227+
}
92249228

9225-
nextCall->walk(TryInjector(cs.getASTContext(), solution));
9229+
stmt->setNextCall(nextCall);
92269230
}
92279231

9228-
stmt->setNextCall(nextCall);
9229-
}
9230-
9231-
// Coerce the pattern to the element type.
9232-
{
9233-
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9234-
options |= TypeResolutionFlags::OverrideType;
9232+
// Coerce the pattern to the element type.
9233+
{
9234+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9235+
options |= TypeResolutionFlags::OverrideType;
92359236

9236-
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9237-
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9238-
};
9237+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9238+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9239+
};
92399240

9240-
// Apply the solution to the pattern as well.
9241-
auto contextualPattern = target.getContextualPattern();
9242-
auto coercedPattern = TypeChecker::coercePatternToType(
9243-
contextualPattern, forEachStmtInfo.initType, options,
9244-
tryRewritePattern);
9245-
if (!coercedPattern)
9246-
return llvm::None;
9241+
// Apply the solution to the pattern as well.
9242+
auto contextualPattern = target.getContextualPattern();
9243+
auto coercedPattern = TypeChecker::coercePatternToType(
9244+
contextualPattern, sequenceIterationInfo.initType, options,
9245+
tryRewritePattern);
9246+
if (!coercedPattern)
9247+
return llvm::None;
92479248

9248-
stmt->setPattern(coercedPattern);
9249-
resultTarget.setPattern(coercedPattern);
9250-
}
9249+
stmt->setPattern(coercedPattern);
9250+
resultTarget.setPattern(coercedPattern);
9251+
}
92519252

9252-
// Apply the solution to the filtering condition, if there is one.
9253-
if (auto *whereExpr = stmt->getWhere()) {
9254-
auto whereTarget = *cs.getTargetFor(whereExpr);
9253+
// Apply the solution to the filtering condition, if there is one.
9254+
if (auto *whereExpr = stmt->getWhere()) {
9255+
auto whereTarget = *cs.getTargetFor(whereExpr);
92559256

9256-
auto rewrittenTarget = rewriteTarget(whereTarget);
9257-
if (!rewrittenTarget)
9258-
return llvm::None;
9257+
auto rewrittenTarget = rewriteTarget(whereTarget);
9258+
if (!rewrittenTarget)
9259+
return llvm::None;
92599260

9260-
stmt->setWhere(rewrittenTarget->getAsExpr());
9261-
}
9261+
stmt->setWhere(rewrittenTarget->getAsExpr());
9262+
}
92629263

9263-
// Convert that llvm::Optional<Element> value to the type of the pattern.
9264-
auto optPatternType = OptionalType::get(forEachStmtInfo.initType);
9265-
Type nextResultType = OptionalType::get(forEachStmtInfo.elementType);
9266-
if (!optPatternType->isEqual(nextResultType)) {
9267-
ASTContext &ctx = cs.getASTContext();
9268-
OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr(
9269-
stmt->getInLoc(), nextResultType->getOptionalObjectType(),
9270-
/*isPlaceholder=*/true);
9271-
Expr *convertElementExpr = elementExpr;
9272-
if (TypeChecker::typeCheckExpression(
9273-
convertElementExpr, dc,
9274-
/*contextualInfo=*/{forEachStmtInfo.initType, CTP_CoerceOperand})
9275-
.isNull()) {
9276-
return llvm::None;
9264+
// Convert that llvm::Optional<Element> value to the type of the pattern.
9265+
auto optPatternType = OptionalType::get(sequenceIterationInfo.initType);
9266+
Type nextResultType = OptionalType::get(sequenceIterationInfo.elementType);
9267+
if (!optPatternType->isEqual(nextResultType)) {
9268+
ASTContext &ctx = cs.getASTContext();
9269+
OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr(
9270+
stmt->getInLoc(), nextResultType->getOptionalObjectType(),
9271+
/*isPlaceholder=*/true);
9272+
Expr *convertElementExpr = elementExpr;
9273+
if (TypeChecker::typeCheckExpression(
9274+
convertElementExpr, dc,
9275+
/*contextualInfo=*/
9276+
{sequenceIterationInfo.initType, CTP_CoerceOperand})
9277+
.isNull()) {
9278+
return llvm::None;
9279+
}
9280+
elementExpr->setIsPlaceholder(false);
9281+
stmt->setElementExpr(elementExpr);
9282+
stmt->setConvertElementExpr(convertElementExpr);
92779283
}
9278-
elementExpr->setIsPlaceholder(false);
9279-
stmt->setElementExpr(elementExpr);
9280-
stmt->setConvertElementExpr(convertElementExpr);
9281-
}
92829284

9283-
// Get the conformance of the sequence type to the Sequence protocol.
9284-
{
9285+
// Get the conformance of the sequence type to the Sequence protocol.
92859286
auto sequenceProto = TypeChecker::getProtocol(
92869287
cs.getASTContext(), stmt->getForLoc(),
92879288
stmt->getAwaitLoc().isValid() ? KnownProtocolKind::AsyncSequence
92889289
: KnownProtocolKind::Sequence);
92899290

9290-
auto type = forEachStmtInfo.sequenceType->getRValueType();
9291+
auto type = sequenceIterationInfo.sequenceType->getRValueType();
92919292
if (type->isExistentialType()) {
92929293
auto *contextualLoc = solution.getConstraintLocator(
92939294
parsedSequence, LocatorPathElt::ContextualType(CTP_ForEachSequence));
@@ -9300,6 +9301,48 @@ static llvm::Optional<SyntacticElementTarget> applySolutionToForEachStmt(
93009301
stmt->setSequenceConformance(sequenceConformance);
93019302
}
93029303

9304+
if (forEachStmtInfo.isa<PackIterationInfo>()) {
9305+
auto packIterationInfo = *forEachStmtInfo.dyn_cast<PackIterationInfo>();
9306+
9307+
// First, let's apply the solution to the expression.
9308+
auto makeSequenceTarget = *cs.getTargetFor(parsedSequence);
9309+
auto rewrittenTarget = rewriteTarget(makeSequenceTarget);
9310+
if (!rewrittenTarget)
9311+
return llvm::None;
9312+
9313+
// Coerce the pattern to the element type.
9314+
{
9315+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
9316+
options |= TypeResolutionFlags::OverrideType;
9317+
9318+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9319+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9320+
};
9321+
9322+
// Apply the solution to the pattern as well.
9323+
auto contextualPattern = target.getContextualPattern();
9324+
auto coercedPattern = TypeChecker::coercePatternToType(
9325+
contextualPattern, packIterationInfo.expansion->getType(), options,
9326+
tryRewritePattern);
9327+
if (!coercedPattern)
9328+
return llvm::None;
9329+
9330+
stmt->setPattern(coercedPattern);
9331+
resultTarget.setPattern(coercedPattern);
9332+
}
9333+
9334+
// Apply the solution to the filtering condition, if there is one.
9335+
if (auto *whereExpr = stmt->getWhere()) {
9336+
auto whereTarget = *cs.getTargetFor(whereExpr);
9337+
9338+
auto rewrittenTarget = rewriteTarget(whereTarget);
9339+
if (!rewrittenTarget)
9340+
return llvm::None;
9341+
9342+
stmt->setWhere(rewrittenTarget->getAsExpr());
9343+
}
9344+
}
9345+
93039346
return resultTarget;
93049347
}
93059348

0 commit comments

Comments
 (0)