Skip to content

[SE-0408] Enable Pack Iteration #67594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5855,6 +5855,7 @@ enum class PropertyWrapperSynthesizedPropertyKind {
class VarDecl : public AbstractStorageDecl {
friend class NamingPatternRequest;
NamedPattern *NamingPattern = nullptr;
GenericEnvironment *OpenedElementEnvironment = nullptr;

public:
enum class Introducer : uint8_t {
Expand Down Expand Up @@ -5982,6 +5983,13 @@ class VarDecl : public AbstractStorageDecl {
NamedPattern *getNamingPattern() const;
void setNamingPattern(NamedPattern *Pat);

GenericEnvironment *getOpenedElementEnvironment() const {
return OpenedElementEnvironment;
}
void setOpenedElementEnvironment(GenericEnvironment *Env) {
OpenedElementEnvironment = Env;
}

/// If this is a VarDecl that does not belong to a CaseLabelItem's pattern,
/// return this. Otherwise, this VarDecl must belong to a CaseStmt's
/// CaseLabelItem. In that case, return the first case label item of the first
Expand Down
11 changes: 9 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -5834,8 +5834,8 @@ ERROR(expansion_not_allowed,none,
"pack expansion %0 can only appear in a function parameter list, "
"tuple element, or generic argument of a variadic type", (Type))
ERROR(expansion_expr_not_allowed,none,
"value pack expansion can only appear inside a function argument list "
"or tuple element", ())
"value pack expansion can only appear inside a function argument list, "
"tuple element, or as the expression of a for-in loop", ())
ERROR(invalid_expansion_argument,none,
"cannot pass value pack expansion to non-pack parameter of type %0",
(Type))
Expand Down Expand Up @@ -7715,5 +7715,12 @@ ERROR(referencebindings_binding_must_be_to_lvalue,none,
ERROR(result_depends_on_no_result,none,
"Incorrect use of %0 with no result", (StringRef))

//------------------------------------------------------------------------------
// MARK: Pack Iteration Diagnostics
//------------------------------------------------------------------------------

ERROR(pack_iteration_where_clause_not_supported, none,
"'where' clause in pack iteration is not supported", ())

#define UNDEFINE_DIAGNOSTIC_MACROS
#include "DefineDiagnosticMacros.h"
23 changes: 23 additions & 0 deletions include/swift/Sema/CSFix.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,9 @@ enum class FixKind : uint8_t {
/// Allow pack expansion expressions in a context that does not support them.
AllowInvalidPackExpansion,

/// Ignore `where` clause in a for-in loop with a pack expansion expression.
IgnoreWhereClauseInPackIteration,

/// Allow a pack expansion parameter of N elements to be matched
/// with a single tuple literal argument of the same arity.
DestructureTupleToMatchPackExpansionParameter,
Expand Down Expand Up @@ -2223,6 +2226,26 @@ class AllowInvalidPackExpansion final : public ConstraintFix {
}
};

class IgnoreWhereClauseInPackIteration final : public ConstraintFix {
IgnoreWhereClauseInPackIteration(ConstraintSystem &cs,
ConstraintLocator *locator)
: ConstraintFix(cs, FixKind::IgnoreWhereClauseInPackIteration, locator) {}

public:
std::string getName() const override {
return "ignore where clause in pack iteration";
}

bool diagnose(const Solution &solution, bool asNote = false) const override;

static IgnoreWhereClauseInPackIteration *create(ConstraintSystem &cs,
ConstraintLocator *locator);

static bool classof(const ConstraintFix *fix) {
return fix->getKind() == FixKind::IgnoreWhereClauseInPackIteration;
}
};

class CollectionElementContextualMismatch final
: public ContextualMismatch,
private llvm::TrailingObjects<CollectionElementContextualMismatch,
Expand Down
5 changes: 5 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -3759,6 +3759,11 @@ class ConstraintSystem {
RememberChoice_t rememberChoice,
ConstraintLocatorBuilder locator,
ConstraintFix *compatFix = nullptr);

/// Add a materialize constraint for a pack expansion.
TypeVariableType *
addMaterializePackExpansionConstraint(Type patternType,
ConstraintLocatorBuilder locator);

/// Add a disjunction constraint.
void
Expand Down
20 changes: 17 additions & 3 deletions include/swift/Sema/SyntacticElementTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@
#include "swift/AST/Pattern.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/TypeLoc.h"
#include "swift/Basic/TaggedUnion.h"
#include "swift/Sema/ConstraintLocator.h"
#include "swift/Sema/ContextualTypeInfo.h"

namespace swift {

namespace constraints {
/// Describes information about a for-each loop that needs to be tracked
/// within the constraint system.
struct ForEachStmtInfo {
/// Describes information specific to a sequence
/// in a for-each loop.
struct SequenceIterationInfo {
/// The type of the sequence.
Type sequenceType;

Expand All @@ -47,6 +48,19 @@ struct ForEachStmtInfo {
Expr *nextCall;
};

/// Describes information specific to a pack expansion expression
/// in a for-each loop.
struct PackIterationInfo {
/// The type of the pattern that matches the elements.
Type patternType;
};

/// Describes information about a for-each loop that needs to be tracked
/// within the constraint system.
struct ForEachStmtInfo : TaggedUnion<SequenceIterationInfo, PackIterationInfo> {
using TaggedUnion::TaggedUnion;
};

/// Describes the target to which a constraint system's solution can be
/// applied.
class SyntacticElementTarget {
Expand Down
22 changes: 22 additions & 0 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,13 @@ class Verifier : public ASTWalker {
if (!shouldVerify(cast<Stmt>(S)))
return false;

if (auto *expansion =
dyn_cast<PackExpansionExpr>(S->getParsedSequence())) {
if (!shouldVerify(expansion)) {
return false;
}
}

if (!S->getElementExpr())
return true;

Expand All @@ -804,6 +811,11 @@ class Verifier : public ASTWalker {
}

void cleanup(ForEachStmt *S) {
if (auto *expansion =
dyn_cast<PackExpansionExpr>(S->getParsedSequence())) {
cleanup(expansion);
}

if (!S->getElementExpr())
return;

Expand Down Expand Up @@ -2605,6 +2617,16 @@ class Verifier : public ASTWalker {
abort();
}

// If we are performing pack iteration, variables have to carry the
// generic environment. Catching the missing environment here will prevent
// the code from being lowered.
if (var->getTypeInContext()->is<ErrorType>()) {
Out << "VarDecl is missing a Generic Environment: ";
var->getInterfaceType().print(Out);
Out << "\n";
abort();
}

// The fact that this is *directly* be a reference storage type
// cuts the code down quite a bit in getTypeOfReference.
if (var->getAttrs().hasAttribute<ReferenceOwnershipAttr>() !=
Expand Down
8 changes: 3 additions & 5 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1905,11 +1905,9 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
//
// If for-in is already type-checked, the type-checked version
// of the sequence is going to be visited as part of `iteratorVar`.
if (S->getTypeCheckedSequence()) {
if (auto IteratorVar = S->getIteratorVar()) {
if (doIt(IteratorVar))
return nullptr;
}
if (auto IteratorVar = S->getIteratorVar()) {
if (doIt(IteratorVar))
return nullptr;

if (auto NextCall = S->getNextCall()) {
if ((NextCall = doIt(NextCall)))
Expand Down
5 changes: 5 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7135,6 +7135,11 @@ VarDecl::VarDecl(DeclKind kind, bool isStatic, VarDecl::Introducer introducer,
}

Type VarDecl::getTypeInContext() const {
// If we are performing pack iteration, use the generic environment of the
// pack expansion expression to get the right context of a local variable.
if (auto *env = getOpenedElementEnvironment())
return GenericEnvironment::mapTypeIntoContext(env, getInterfaceType());

return getDeclContext()->mapTypeIntoContext(getInterfaceType());
}

Expand Down
10 changes: 9 additions & 1 deletion lib/AST/GenericSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,16 @@ void GenericSignatureImpl::forEachParam(

for (auto req : getRequirements()) {
GenericTypeParamType *gp;
bool isCanonical = false;
switch (req.getKind()) {
case RequirementKind::SameType: {
if (req.getSecondType()->isParameterPack() !=
req.getFirstType()->isParameterPack()) {
// This is a same-element requirement, which does not make
// type parameters non-canonical.
isCanonical = true;
}

if (auto secondGP = req.getSecondType()->getAs<GenericTypeParamType>()) {
// If two generic parameters are same-typed, then the right-hand one
// is non-canonical.
Expand Down Expand Up @@ -136,7 +144,7 @@ void GenericSignatureImpl::forEachParam(
}

unsigned index = GenericParamKey(gp).findIndexIn(genericParams);
genericParamsAreCanonical[index] = false;
genericParamsAreCanonical[index] = isCanonical;
}

// Call the callback with each parameter and the result of the above analysis.
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ void ForEachStmt::setPattern(Pattern *p) {
}

Expr *ForEachStmt::getTypeCheckedSequence() const {
if (auto *expansion = dyn_cast<PackExpansionExpr>(getParsedSequence()))
return expansion;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to do more checking here or add a bit to the statement to make sure that we are only returning non-null pack expansion expressions iff type-checking was successful. For regular for-in statements type-checker sets iteratorVar, for pack iteration I think we just need to tag the expression itself (no harm if that applies for both packs and regular sequences too!).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we might need to rethink the purpose of this method with this feature in mind. The bug where pack iteration was not working in closures had to partially deal with the fact that whoever calls it assumes that the type-checking was successful and we are dealing with a sequence if the method is non-null: 90ca95c#diff-f7b20ead68204a38f1ecf3cd2202f98fbcbfc193e117458b8ce6e612cb8855c7R1893. When we are dealing with a pack expansion, the assumed "sequence" will be non-null (but containing an expansion expr) and we will go into the sequence code path, which is undesired. So the only way to use this method "right" is to check if casting the result into an expansion expression is non-null to make sure that the sequence is null... Which is not obvious at all and we need to fix this. I think I'll put a TODO on it for now


return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
}

Expand Down
2 changes: 2 additions & 0 deletions lib/IDE/CodeCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,8 @@ void CodeCompletionCallbacksImpl::addKeywords(CodeCompletionResultSink &Sink,
addSuperKeyword(Sink, CurDeclContext);
addExprKeywords(Sink, CurDeclContext);
addAnyTypeKeyword(Sink, CurDeclContext->getASTContext().TheAnyType);
if (Kind == CompletionKind::ForEachSequence)
addKeyword(Sink, "repeat", CodeCompletionKeywordKind::kw_repeat);
break;

case CompletionKind::CallArg:
Expand Down
32 changes: 15 additions & 17 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2724,26 +2724,24 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
///
/// This function will be called within a cleanups scope and with
/// InnermostPackExpansion set up properly for the context.
void emitDynamicPackLoop(SILLocation loc,
CanPackType formalPackType,
unsigned componentIndex,
SILValue startingAfterIndexWithinComponent,
SILValue limitWithinComponent,
GenericEnvironment *openedElementEnv,
bool reverse,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex,
SILValue packIndex)> emitBody);
void emitDynamicPackLoop(
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
SILValue startingAfterIndexWithinComponent, SILValue limitWithinComponent,
GenericEnvironment *openedElementEnv, bool reverse,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex, SILValue packIndex)>
emitBody,
SILBasicBlock *loopLatch = nullptr);

/// A convenience version of dynamic pack loop that visits an entire
/// pack expansion component in forward order.
void emitDynamicPackLoop(SILLocation loc,
CanPackType formalPackType,
unsigned componentIndex,
GenericEnvironment *openedElementEnv,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex,
SILValue packIndex)> emitBody);
void emitDynamicPackLoop(
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
GenericEnvironment *openedElementEnv,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex, SILValue packIndex)>
emitBody,
SILBasicBlock *loopLatch = nullptr);

/// Emit a transform on each element of a pack-expansion component
/// of a pack, write the result into a pack-expansion component of
Expand Down
43 changes: 25 additions & 18 deletions lib/SILGen/SILGenPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,28 +662,27 @@ void SILGenFunction::projectTupleElementsToPack(SILLocation loc,
});
}

void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
CanPackType formalPackType,
unsigned componentIndex,
GenericEnvironment *openedElementEnv,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex,
SILValue packIndex)> emitBody) {
void SILGenFunction::emitDynamicPackLoop(
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
GenericEnvironment *openedElementEnv,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex, SILValue packIndex)>
emitBody,
SILBasicBlock *loopLatch) {
return emitDynamicPackLoop(loc, formalPackType, componentIndex,
/*startAfter*/ SILValue(), /*limit*/ SILValue(),
openedElementEnv, /*reverse*/false, emitBody);
openedElementEnv, /*reverse*/ false, emitBody,
loopLatch);
}

void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
CanPackType formalPackType,
unsigned componentIndex,
SILValue startingAfterIndexInComponent,
SILValue limitWithinComponent,
GenericEnvironment *openedElementEnv,
bool reverse,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex,
SILValue packIndex)> emitBody) {
void SILGenFunction::emitDynamicPackLoop(
SILLocation loc, CanPackType formalPackType, unsigned componentIndex,
SILValue startingAfterIndexInComponent, SILValue limitWithinComponent,
GenericEnvironment *openedElementEnv, bool reverse,
llvm::function_ref<void(SILValue indexWithinComponent,
SILValue packExpansionIndex, SILValue packIndex)>
emitBody,
SILBasicBlock *loopLatch) {
assert(isa<PackExpansionType>(formalPackType.getElementType(componentIndex)));
assert((!startingAfterIndexInComponent || !reverse) &&
"cannot reverse with a starting index");
Expand Down Expand Up @@ -764,6 +763,7 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
// the incoming index - 1 if reverse)
SILValue curIndex = incomingIndex;
if (reverse) {
assert(!loopLatch && "Only forward iteration supported with loop latch");
curIndex = B.createBuiltinBinaryFunction(loc, "sub", wordTy, wordTy,
{ incomingIndex, one });
}
Expand Down Expand Up @@ -791,6 +791,13 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc,
{
FullExpr scope(Cleanups, CleanupLocation(loc));
emitBody(curIndex, packExpansionIndex, packIndex);
if (loopLatch) {
B.createBranch(loc, loopLatch);
}
}

if (loopLatch) {
B.emitBlock(loopLatch);
}

// The index to pass to the loop condition block (the current index + 1
Expand Down
Loading