Skip to content

Commit b6d0afb

Browse files
authored
Merge pull request #67594 from simanerush/simanerush/pack-iteration-impl
[SE-0408] Enable Pack Iteration
2 parents 452c624 + 48cb330 commit b6d0afb

26 files changed

+839
-184
lines changed

include/swift/AST/Decl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5863,6 +5863,7 @@ enum class PropertyWrapperSynthesizedPropertyKind {
58635863
class VarDecl : public AbstractStorageDecl {
58645864
friend class NamingPatternRequest;
58655865
NamedPattern *NamingPattern = nullptr;
5866+
GenericEnvironment *OpenedElementEnvironment = nullptr;
58665867

58675868
public:
58685869
enum class Introducer : uint8_t {
@@ -5990,6 +5991,13 @@ class VarDecl : public AbstractStorageDecl {
59905991
NamedPattern *getNamingPattern() const;
59915992
void setNamingPattern(NamedPattern *Pat);
59925993

5994+
GenericEnvironment *getOpenedElementEnvironment() const {
5995+
return OpenedElementEnvironment;
5996+
}
5997+
void setOpenedElementEnvironment(GenericEnvironment *Env) {
5998+
OpenedElementEnvironment = Env;
5999+
}
6000+
59936001
/// If this is a VarDecl that does not belong to a CaseLabelItem's pattern,
59946002
/// return this. Otherwise, this VarDecl must belong to a CaseStmt's
59956003
/// CaseLabelItem. In that case, return the first case label item of the first

include/swift/AST/DiagnosticsSema.def

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5859,8 +5859,8 @@ ERROR(expansion_not_allowed,none,
58595859
"pack expansion %0 can only appear in a function parameter list, "
58605860
"tuple element, or generic argument of a variadic type", (Type))
58615861
ERROR(expansion_expr_not_allowed,none,
5862-
"value pack expansion can only appear inside a function argument list "
5863-
"or tuple element", ())
5862+
"value pack expansion can only appear inside a function argument list, "
5863+
"tuple element, or as the expression of a for-in loop", ())
58645864
ERROR(invalid_expansion_argument,none,
58655865
"cannot pass value pack expansion to non-pack parameter of type %0",
58665866
(Type))
@@ -7730,5 +7730,12 @@ ERROR(referencebindings_binding_must_be_to_lvalue,none,
77307730
ERROR(result_depends_on_no_result,none,
77317731
"Incorrect use of %0 with no result", (StringRef))
77327732

7733+
//------------------------------------------------------------------------------
7734+
// MARK: Pack Iteration Diagnostics
7735+
//------------------------------------------------------------------------------
7736+
7737+
ERROR(pack_iteration_where_clause_not_supported, none,
7738+
"'where' clause in pack iteration is not supported", ())
7739+
77337740
#define UNDEFINE_DIAGNOSTIC_MACROS
77347741
#include "DefineDiagnosticMacros.h"

include/swift/Sema/CSFix.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ enum class FixKind : uint8_t {
440440
/// Allow pack expansion expressions in a context that does not support them.
441441
AllowInvalidPackExpansion,
442442

443+
/// Ignore `where` clause in a for-in loop with a pack expansion expression.
444+
IgnoreWhereClauseInPackIteration,
445+
443446
/// Allow a pack expansion parameter of N elements to be matched
444447
/// with a single tuple literal argument of the same arity.
445448
DestructureTupleToMatchPackExpansionParameter,
@@ -2227,6 +2230,26 @@ class AllowInvalidPackExpansion final : public ConstraintFix {
22272230
}
22282231
};
22292232

2233+
class IgnoreWhereClauseInPackIteration final : public ConstraintFix {
2234+
IgnoreWhereClauseInPackIteration(ConstraintSystem &cs,
2235+
ConstraintLocator *locator)
2236+
: ConstraintFix(cs, FixKind::IgnoreWhereClauseInPackIteration, locator) {}
2237+
2238+
public:
2239+
std::string getName() const override {
2240+
return "ignore where clause in pack iteration";
2241+
}
2242+
2243+
bool diagnose(const Solution &solution, bool asNote = false) const override;
2244+
2245+
static IgnoreWhereClauseInPackIteration *create(ConstraintSystem &cs,
2246+
ConstraintLocator *locator);
2247+
2248+
static bool classof(const ConstraintFix *fix) {
2249+
return fix->getKind() == FixKind::IgnoreWhereClauseInPackIteration;
2250+
}
2251+
};
2252+
22302253
class CollectionElementContextualMismatch final
22312254
: public ContextualMismatch,
22322255
private llvm::TrailingObjects<CollectionElementContextualMismatch,

include/swift/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3765,6 +3765,11 @@ class ConstraintSystem {
37653765
RememberChoice_t rememberChoice,
37663766
ConstraintLocatorBuilder locator,
37673767
ConstraintFix *compatFix = nullptr);
3768+
3769+
/// Add a materialize constraint for a pack expansion.
3770+
TypeVariableType *
3771+
addMaterializePackExpansionConstraint(Type patternType,
3772+
ConstraintLocatorBuilder locator);
37683773

37693774
/// Add a disjunction constraint.
37703775
void

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@
2222
#include "swift/AST/Pattern.h"
2323
#include "swift/AST/Stmt.h"
2424
#include "swift/AST/TypeLoc.h"
25+
#include "swift/Basic/TaggedUnion.h"
2526
#include "swift/Sema/ConstraintLocator.h"
2627
#include "swift/Sema/ContextualTypeInfo.h"
2728

2829
namespace swift {
2930

3031
namespace constraints {
31-
/// Describes information about a for-each loop that needs to be tracked
32-
/// within the constraint system.
33-
struct ForEachStmtInfo {
32+
/// Describes information specific to a sequence
33+
/// in a for-each loop.
34+
struct SequenceIterationInfo {
3435
/// The type of the sequence.
3536
Type sequenceType;
3637

@@ -47,6 +48,19 @@ struct ForEachStmtInfo {
4748
Expr *nextCall;
4849
};
4950

51+
/// Describes information specific to a pack expansion expression
52+
/// in a for-each loop.
53+
struct PackIterationInfo {
54+
/// The type of the pattern that matches the elements.
55+
Type patternType;
56+
};
57+
58+
/// Describes information about a for-each loop that needs to be tracked
59+
/// within the constraint system.
60+
struct ForEachStmtInfo : TaggedUnion<SequenceIterationInfo, PackIterationInfo> {
61+
using TaggedUnion::TaggedUnion;
62+
};
63+
5064
/// Describes the target to which a constraint system's solution can be
5165
/// applied.
5266
class SyntacticElementTarget {

lib/AST/ASTVerifier.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,13 @@ class Verifier : public ASTWalker {
803803
if (!shouldVerify(cast<Stmt>(S)))
804804
return false;
805805

806+
if (auto *expansion =
807+
dyn_cast<PackExpansionExpr>(S->getParsedSequence())) {
808+
if (!shouldVerify(expansion)) {
809+
return false;
810+
}
811+
}
812+
806813
if (!S->getElementExpr())
807814
return true;
808815

@@ -812,6 +819,11 @@ class Verifier : public ASTWalker {
812819
}
813820

814821
void cleanup(ForEachStmt *S) {
822+
if (auto *expansion =
823+
dyn_cast<PackExpansionExpr>(S->getParsedSequence())) {
824+
cleanup(expansion);
825+
}
826+
815827
if (!S->getElementExpr())
816828
return;
817829

@@ -2632,6 +2644,16 @@ class Verifier : public ASTWalker {
26322644
abort();
26332645
}
26342646

2647+
// If we are performing pack iteration, variables have to carry the
2648+
// generic environment. Catching the missing environment here will prevent
2649+
// the code from being lowered.
2650+
if (var->getTypeInContext()->is<ErrorType>()) {
2651+
Out << "VarDecl is missing a Generic Environment: ";
2652+
var->getInterfaceType().print(Out);
2653+
Out << "\n";
2654+
abort();
2655+
}
2656+
26352657
// The fact that this is *directly* be a reference storage type
26362658
// cuts the code down quite a bit in getTypeOfReference.
26372659
if (var->getAttrs().hasAttribute<ReferenceOwnershipAttr>() !=

lib/AST/ASTWalker.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,11 +1905,9 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
19051905
//
19061906
// If for-in is already type-checked, the type-checked version
19071907
// of the sequence is going to be visited as part of `iteratorVar`.
1908-
if (S->getTypeCheckedSequence()) {
1909-
if (auto IteratorVar = S->getIteratorVar()) {
1910-
if (doIt(IteratorVar))
1911-
return nullptr;
1912-
}
1908+
if (auto IteratorVar = S->getIteratorVar()) {
1909+
if (doIt(IteratorVar))
1910+
return nullptr;
19131911

19141912
if (auto NextCall = S->getNextCall()) {
19151913
if ((NextCall = doIt(NextCall)))

lib/AST/Decl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7149,6 +7149,11 @@ VarDecl::VarDecl(DeclKind kind, bool isStatic, VarDecl::Introducer introducer,
71497149
}
71507150

71517151
Type VarDecl::getTypeInContext() const {
7152+
// If we are performing pack iteration, use the generic environment of the
7153+
// pack expansion expression to get the right context of a local variable.
7154+
if (auto *env = getOpenedElementEnvironment())
7155+
return GenericEnvironment::mapTypeIntoContext(env, getInterfaceType());
7156+
71527157
return getDeclContext()->mapTypeIntoContext(getInterfaceType());
71537158
}
71547159

lib/AST/GenericSignature.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,16 @@ void GenericSignatureImpl::forEachParam(
105105

106106
for (auto req : getRequirements()) {
107107
GenericTypeParamType *gp;
108+
bool isCanonical = false;
108109
switch (req.getKind()) {
109110
case RequirementKind::SameType: {
111+
if (req.getSecondType()->isParameterPack() !=
112+
req.getFirstType()->isParameterPack()) {
113+
// This is a same-element requirement, which does not make
114+
// type parameters non-canonical.
115+
isCanonical = true;
116+
}
117+
110118
if (auto secondGP = req.getSecondType()->getAs<GenericTypeParamType>()) {
111119
// If two generic parameters are same-typed, then the right-hand one
112120
// is non-canonical.
@@ -136,7 +144,7 @@ void GenericSignatureImpl::forEachParam(
136144
}
137145

138146
unsigned index = GenericParamKey(gp).findIndexIn(genericParams);
139-
genericParamsAreCanonical[index] = false;
147+
genericParamsAreCanonical[index] = isCanonical;
140148
}
141149

142150
// Call the callback with each parameter and the result of the above analysis.

lib/AST/Stmt.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ void ForEachStmt::setPattern(Pattern *p) {
446446
}
447447

448448
Expr *ForEachStmt::getTypeCheckedSequence() const {
449+
if (auto *expansion = dyn_cast<PackExpansionExpr>(getParsedSequence()))
450+
return expansion;
451+
449452
return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
450453
}
451454

lib/IDE/CodeCompletion.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,8 @@ void CodeCompletionCallbacksImpl::addKeywords(CodeCompletionResultSink &Sink,
10701070
addSuperKeyword(Sink, CurDeclContext);
10711071
addExprKeywords(Sink, CurDeclContext);
10721072
addAnyTypeKeyword(Sink, CurDeclContext->getASTContext().TheAnyType);
1073+
if (Kind == CompletionKind::ForEachSequence)
1074+
addKeyword(Sink, "repeat", CodeCompletionKeywordKind::kw_repeat);
10731075
break;
10741076

10751077
case CompletionKind::CallArg:

lib/SILGen/SILGenFunction.h

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2724,26 +2724,24 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
27242724
///
27252725
/// This function will be called within a cleanups scope and with
27262726
/// InnermostPackExpansion set up properly for the context.
2727-
void emitDynamicPackLoop(SILLocation loc,
2728-
CanPackType formalPackType,
2729-
unsigned componentIndex,
2730-
SILValue startingAfterIndexWithinComponent,
2731-
SILValue limitWithinComponent,
2732-
GenericEnvironment *openedElementEnv,
2733-
bool reverse,
2734-
llvm::function_ref<void(SILValue indexWithinComponent,
2735-
SILValue packExpansionIndex,
2736-
SILValue packIndex)> emitBody);
2727+
void emitDynamicPackLoop(
2728+
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
2729+
SILValue startingAfterIndexWithinComponent, SILValue limitWithinComponent,
2730+
GenericEnvironment *openedElementEnv, bool reverse,
2731+
llvm::function_ref<void(SILValue indexWithinComponent,
2732+
SILValue packExpansionIndex, SILValue packIndex)>
2733+
emitBody,
2734+
SILBasicBlock *loopLatch = nullptr);
27372735

27382736
/// A convenience version of dynamic pack loop that visits an entire
27392737
/// pack expansion component in forward order.
2740-
void emitDynamicPackLoop(SILLocation loc,
2741-
CanPackType formalPackType,
2742-
unsigned componentIndex,
2743-
GenericEnvironment *openedElementEnv,
2744-
llvm::function_ref<void(SILValue indexWithinComponent,
2745-
SILValue packExpansionIndex,
2746-
SILValue packIndex)> emitBody);
2738+
void emitDynamicPackLoop(
2739+
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
2740+
GenericEnvironment *openedElementEnv,
2741+
llvm::function_ref<void(SILValue indexWithinComponent,
2742+
SILValue packExpansionIndex, SILValue packIndex)>
2743+
emitBody,
2744+
SILBasicBlock *loopLatch = nullptr);
27472745

27482746
/// Emit a transform on each element of a pack-expansion component
27492747
/// of a pack, write the result into a pack-expansion component of

lib/SILGen/SILGenPack.cpp

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -662,28 +662,27 @@ void SILGenFunction::projectTupleElementsToPack(SILLocation loc,
662662
});
663663
}
664664

665-
void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
666-
CanPackType formalPackType,
667-
unsigned componentIndex,
668-
GenericEnvironment *openedElementEnv,
669-
llvm::function_ref<void(SILValue indexWithinComponent,
670-
SILValue packExpansionIndex,
671-
SILValue packIndex)> emitBody) {
665+
void SILGenFunction::emitDynamicPackLoop(
666+
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
667+
GenericEnvironment *openedElementEnv,
668+
llvm::function_ref<void(SILValue indexWithinComponent,
669+
SILValue packExpansionIndex, SILValue packIndex)>
670+
emitBody,
671+
SILBasicBlock *loopLatch) {
672672
return emitDynamicPackLoop(loc, formalPackType, componentIndex,
673673
/*startAfter*/ SILValue(), /*limit*/ SILValue(),
674-
openedElementEnv, /*reverse*/false, emitBody);
674+
openedElementEnv, /*reverse*/ false, emitBody,
675+
loopLatch);
675676
}
676677

677-
void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
678-
CanPackType formalPackType,
679-
unsigned componentIndex,
680-
SILValue startingAfterIndexInComponent,
681-
SILValue limitWithinComponent,
682-
GenericEnvironment *openedElementEnv,
683-
bool reverse,
684-
llvm::function_ref<void(SILValue indexWithinComponent,
685-
SILValue packExpansionIndex,
686-
SILValue packIndex)> emitBody) {
678+
void SILGenFunction::emitDynamicPackLoop(
679+
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
680+
SILValue startingAfterIndexInComponent, SILValue limitWithinComponent,
681+
GenericEnvironment *openedElementEnv, bool reverse,
682+
llvm::function_ref<void(SILValue indexWithinComponent,
683+
SILValue packExpansionIndex, SILValue packIndex)>
684+
emitBody,
685+
SILBasicBlock *loopLatch) {
687686
assert(isa<PackExpansionType>(formalPackType.getElementType(componentIndex)));
688687
assert((!startingAfterIndexInComponent || !reverse) &&
689688
"cannot reverse with a starting index");
@@ -764,6 +763,7 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
764763
// the incoming index - 1 if reverse)
765764
SILValue curIndex = incomingIndex;
766765
if (reverse) {
766+
assert(!loopLatch && "Only forward iteration supported with loop latch");
767767
curIndex = B.createBuiltinBinaryFunction(loc, "sub", wordTy, wordTy,
768768
{ incomingIndex, one });
769769
}
@@ -791,6 +791,13 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
791791
{
792792
FullExpr scope(Cleanups, CleanupLocation(loc));
793793
emitBody(curIndex, packExpansionIndex, packIndex);
794+
if (loopLatch) {
795+
B.createBranch(loc, loopLatch);
796+
}
797+
}
798+
799+
if (loopLatch) {
800+
B.emitBlock(loopLatch);
794801
}
795802

796803
// The index to pass to the loop condition block (the current index + 1

0 commit comments

Comments
 (0)