Skip to content

[SE-0408] Enable nested iteration #70196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,10 @@ class Solution {
llvm::MapVector<PackElementExpr *, PackExpansionExpr *>
PackEnvironments;

/// The outer pack element generic environment to use when dealing with nested
/// pack iteration (see \c getPackElementEnvironment).
llvm::SmallVector<GenericEnvironment *> PackElementGenericEnvironments;

/// The locators of \c Defaultable constraints whose defaults were used.
llvm::DenseSet<ConstraintLocator *> DefaultedConstraints;

Expand Down Expand Up @@ -2344,6 +2348,8 @@ class ConstraintSystem {
llvm::SmallMapVector<PackElementExpr *, PackExpansionExpr *, 2>
PackEnvironments;

llvm::SmallVector<GenericEnvironment *, 4> PackElementGenericEnvironments;

/// The set of functions that have been transformed by a result builder.
llvm::MapVector<AnyFunctionRef, AppliedBuilderTransform>
resultBuilderTransformed;
Expand Down Expand Up @@ -2833,6 +2839,9 @@ class ConstraintSystem {
/// The length of \c PackEnvironments.
unsigned numPackEnvironments;

/// The length of \c PackElementGenericEnvironments.
unsigned numPackElementGenericEnvironments;

/// The length of \c DefaultedConstraints.
unsigned numDefaultedConstraints;

Expand Down
17 changes: 13 additions & 4 deletions include/swift/Sema/SyntacticElementTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class SyntacticElementTarget {
DeclContext *dc;
Pattern *pattern;
bool ignoreWhereClause;
GenericEnvironment *packElementEnv;
ForEachStmtInfo info;
} forEachStmt;

Expand Down Expand Up @@ -239,11 +240,13 @@ class SyntacticElementTarget {
}

SyntacticElementTarget(ForEachStmt *stmt, DeclContext *dc,
bool ignoreWhereClause)
bool ignoreWhereClause,
GenericEnvironment *packElementEnv)
: kind(Kind::forEachStmt) {
forEachStmt.stmt = stmt;
forEachStmt.dc = dc;
forEachStmt.ignoreWhereClause = ignoreWhereClause;
forEachStmt.packElementEnv = packElementEnv;
}

/// Form a target for the initialization of a pattern from an expression.
Expand All @@ -259,9 +262,10 @@ class SyntacticElementTarget {
unsigned patternBindingIndex, bool bindPatternVarsOneWay);

/// Form a target for a for-in loop.
static SyntacticElementTarget forForEachStmt(ForEachStmt *stmt,
DeclContext *dc,
bool ignoreWhereClause = false);
static SyntacticElementTarget
forForEachStmt(ForEachStmt *stmt, DeclContext *dc,
bool ignoreWhereClause = false,
GenericEnvironment *packElementEnv = nullptr);

/// Form a target for a property with an attached property wrapper that is
/// initialized out-of-line.
Expand Down Expand Up @@ -536,6 +540,11 @@ class SyntacticElementTarget {
return forEachStmt.ignoreWhereClause;
}

GenericEnvironment *getPackElementEnv() const {
assert(isForEachStmt());
return forEachStmt.packElementEnv;
}

const ForEachStmtInfo &getForEachStmtInfo() const {
assert(isForEachStmt());
return forEachStmt.info;
Expand Down
7 changes: 5 additions & 2 deletions lib/AST/GenericEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,10 +734,11 @@ GenericEnvironment::mapElementTypeIntoPackContext(Type type) const {

type = type->mapTypeOutOfContext();

auto interfaceType = element->getInterfaceType();

llvm::SmallDenseMap<GenericParamKey, GenericTypeParamType *>
packParamForElement;
auto elementDepth =
sig.getInnermostGenericParams().front()->getDepth() + 1;
auto elementDepth = interfaceType->getRootGenericParam()->getDepth();

for (auto *genericParam : sig.getGenericParams()) {
if (!genericParam->isParameterPack())
Expand Down Expand Up @@ -792,6 +793,8 @@ Type BuildForwardingSubstitutions::operator()(SubstitutableType *type) const {
auto param = type->castTo<GenericTypeParamType>();
if (!param->isParameterPack())
return resultType;
if (resultType->is<PackType>())
return resultType;
return PackType::getSingletonPackExpansion(resultType);
}
return Type();
Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4912,6 +4912,12 @@ bool ConstraintSystem::generateConstraints(
}

case SyntacticElementTarget::Kind::forEachStmt: {

// Cache the outer generic environment, if it exists.
if (target.getPackElementEnv()) {
PackElementGenericEnvironments.push_back(target.getPackElementEnv());
}

// For a for-each statement, generate constraints for the pattern, where
// clause, and sequence traversal.
auto resultTarget = generateForEachStmtConstraints(*this, target);
Expand Down
14 changes: 14 additions & 0 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ Solution ConstraintSystem::finalize() {
for (const auto &packEnv : PackEnvironments)
solution.PackEnvironments.insert(packEnv);

for (const auto &packEltGenericEnv : PackElementGenericEnvironments)
solution.PackElementGenericEnvironments.push_back(packEltGenericEnv);

return solution;
}

Expand Down Expand Up @@ -316,6 +319,12 @@ void ConstraintSystem::applySolution(const Solution &solution) {
PackEnvironments.insert(packEnvironment);
}

// Register the solutions's pack element generic environments.
for (auto &packElementGenericEnvironment :
solution.PackElementGenericEnvironments) {
PackElementGenericEnvironments.push_back(packElementGenericEnvironment);
}

// Register the defaulted type variables.
DefaultedConstraints.insert(solution.DefaultedConstraints.begin(),
solution.DefaultedConstraints.end());
Expand Down Expand Up @@ -647,6 +656,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numOpenedPackExpansionTypes = cs.OpenedPackExpansionTypes.size();
numPackExpansionEnvironments = cs.PackExpansionEnvironments.size();
numPackEnvironments = cs.PackEnvironments.size();
numPackElementGenericEnvironments = cs.PackElementGenericEnvironments.size();
numDefaultedConstraints = cs.DefaultedConstraints.size();
numAddedNodeTypes = cs.addedNodeTypes.size();
numAddedKeyPathComponentTypes = cs.addedKeyPathComponentTypes.size();
Expand Down Expand Up @@ -736,6 +746,10 @@ ConstraintSystem::SolverScope::~SolverScope() {
// Remove any pack environments.
truncate(cs.PackEnvironments, numPackEnvironments);

// Remove any pack element generic environments.
truncate(cs.PackElementGenericEnvironments,
numPackElementGenericEnvironments);

// Remove any defaulted type variables.
truncate(cs.DefaultedConstraints, numDefaultedConstraints);

Expand Down
7 changes: 5 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,9 +792,11 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator,
shapeClass->mapTypeOutOfContext()->getCanonicalType());

auto &ctx = getASTContext();
auto *contextEnv = PackElementGenericEnvironments.empty()
? DC->getGenericEnvironmentOfContext()
: PackElementGenericEnvironments.back();
auto elementSig = ctx.getOpenedElementSignature(
DC->getGenericSignatureOfContext().getCanonicalSignature(), shapeParam);
auto *contextEnv = DC->getGenericEnvironmentOfContext();
contextEnv->getGenericSignature().getCanonicalSignature(), shapeParam);
auto contextSubs = contextEnv->getForwardingSubstitutionMap();
return GenericEnvironment::forOpenedElement(elementSig, uuidAndShape.first,
shapeParam, contextSubs);
Expand Down Expand Up @@ -4403,6 +4405,7 @@ size_t Solution::getTotalMemory() const {
OpenedPackExpansionTypes.getMemorySize() +
PackExpansionEnvironments.getMemorySize() +
size_in_bytes(PackEnvironments) +
PackElementGenericEnvironments.size() +
(DefaultedConstraints.size() * sizeof(void *)) +
ImplicitCallAsFunctionRoots.getMemorySize() +
nodeTypes.getMemorySize() +
Expand Down
5 changes: 3 additions & 2 deletions lib/Sema/SyntacticElementTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,9 @@ SyntacticElementTarget SyntacticElementTarget::forInitialization(

SyntacticElementTarget
SyntacticElementTarget::forForEachStmt(ForEachStmt *stmt, DeclContext *dc,
bool ignoreWhereClause) {
SyntacticElementTarget target(stmt, dc, ignoreWhereClause);
bool ignoreWhereClause,
GenericEnvironment *packElementEnv) {
SyntacticElementTarget target(stmt, dc, ignoreWhereClause, packElementEnv);
return target;
}

Expand Down
6 changes: 4 additions & 2 deletions lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
return hadError;
}

bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt,
GenericEnvironment *packElementEnv) {
auto &Context = dc->getASTContext();
FrontendStatsTracer statsTracer(Context.Stats, "typecheck-for-each", stmt);
PrettyStackTraceStmt stackTrace(Context, "type-checking-for-each", stmt);
Expand All @@ -912,7 +913,8 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
return true;
};

auto target = SyntacticElementTarget::forForEachStmt(stmt, dc);
auto target = SyntacticElementTarget::forForEachStmt(
stmt, dc, /*ignoreWhereClause=*/false, packElementEnv);
if (!typeCheckTarget(target))
return failed();

Expand Down
17 changes: 15 additions & 2 deletions lib/Sema/TypeCheckStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {

StmtChecker(DeclContext *DC) : Ctx(DC->getASTContext()), DC(DC) { }

llvm::SmallVector<GenericEnvironment *, 4> genericSigStack;

//===--------------------------------------------------------------------===//
// Helper Functions.
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1434,17 +1436,28 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
}

Stmt *visitForEachStmt(ForEachStmt *S) {
if (TypeChecker::typeCheckForEachBinding(DC, S))
GenericEnvironment *genericSignature =
genericSigStack.empty() ? nullptr : genericSigStack.back();

if (TypeChecker::typeCheckForEachBinding(DC, S, genericSignature))
return nullptr;

// Type-check the body of the loop.
auto sourceFile = DC->getParentSourceFile();
checkLabeledStmtShadowing(getASTContext(), sourceFile, S);

BraceStmt *Body = S->getBody();

if (auto packExpansion =
dyn_cast<PackExpansionExpr>(S->getParsedSequence()))
genericSigStack.push_back(packExpansion->getGenericEnvironment());

typeCheckStmt(Body);
S->setBody(Body);


if (isa<PackExpansionExpr>(S->getParsedSequence()))
genericSigStack.pop_back();

return S;
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ bool typeCheckPatternBinding(PatternBindingDecl *PBD, unsigned patternNumber,
/// Type-check a for-each loop's pattern binding and sequence together.
///
/// \returns true if a failure occurred.
bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt);
bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt,
GenericEnvironment *packElementEnv);

/// Compute the set of captures for the given function or closure.
void computeCaptures(AnyFunctionRef AFR);
Expand Down
10 changes: 10 additions & 0 deletions test/stmt/foreach.swift
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,14 @@ do {
// expected-error@-1 {{'where' clause in pack iteration is not supported}}
}
}

func nested<each T, each U>(value: repeat each T, value1: repeat each U) {
for e1 in repeat each value {
for _ in [] {}
for e2 in repeat each value1 {
let y = e1 // Ok
}
let x = e1 // Ok
}
}
}