Skip to content

Commit c69590f

Browse files
committed
[CS] Clean up pack expansion environment handling a little
- Track environments for `PackExpansionExpr` directly instead of using a locator. - Split up the querying and creation of the environment such that the mismatch logic can be done directly in CSSimplify instead of duplicating it. - Just store the environment directly instead of the shape and UUID.
1 parent 0b57ca6 commit c69590f

File tree

8 files changed

+122
-73
lines changed

8 files changed

+122
-73
lines changed

include/swift/Sema/CSTrail.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ LOCATOR_CHANGE(RecordedAppliedDisjunction, AppliedDisjunctions)
5555
LOCATOR_CHANGE(RecordedMatchCallArgumentResult, argumentMatchingChoices)
5656
LOCATOR_CHANGE(RecordedOpenedTypes, OpenedTypes)
5757
LOCATOR_CHANGE(RecordedOpenedExistentialType, OpenedExistentialTypes)
58-
LOCATOR_CHANGE(RecordedPackExpansionEnvironment, PackExpansionEnvironments)
5958
LOCATOR_CHANGE(RecordedDefaultedConstraint, DefaultedConstraints)
6059
LOCATOR_CHANGE(ResolvedOverload, ResolvedOverloads)
6160
LOCATOR_CHANGE(RecordedArgumentList, ArgumentLists)
@@ -96,6 +95,7 @@ CHANGE(AddedFix)
9695
CHANGE(AddedFixedRequirement)
9796
CHANGE(RecordedOpenedPackExpansionType)
9897
CHANGE(RecordedPackElementExpansion)
98+
CHANGE(RecordedPackExpansionEnvironment)
9999
CHANGE(RecordedNodeType)
100100
CHANGE(RecordedKeyPathComponentType)
101101
CHANGE(RecordedResultBuilderTransform)

include/swift/Sema/CSTrail.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class SolverTrail {
150150
ConstraintFix *TheFix;
151151
ConstraintLocator *TheLocator;
152152
PackExpansionType *TheExpansion;
153+
PackExpansionExpr *TheExpansionExpr;
153154
PackElementExpr *TheElement;
154155
Expr *TheExpr;
155156
Stmt *TheStmt;
@@ -215,6 +216,10 @@ class SolverTrail {
215216
/// to its parent expansion expression.
216217
static Change RecordedPackElementExpansion(PackElementExpr *packElement);
217218

219+
/// Create a change that records the GenericEnvironment for a given
220+
/// PackExpansionExpr.
221+
static Change RecordedPackExpansionEnvironment(PackExpansionExpr *expr);
222+
218223
/// Create a change that recorded an assignment of a type to an AST node.
219224
static Change RecordedNodeType(ASTNode node, Type oldType);
220225

include/swift/Sema/ConstraintSystem.h

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,9 +1560,9 @@ class Solution {
15601560
llvm::DenseMap<PackExpansionType *, TypeVariableType *>
15611561
OpenedPackExpansionTypes;
15621562

1563-
/// The pack expansion environment that can open pack elements for
1564-
/// a given locator.
1565-
llvm::DenseMap<ConstraintLocator *, std::pair<UUID, Type>>
1563+
/// The generic environment that can open pack elements for a given
1564+
/// pack expansion.
1565+
llvm::DenseMap<PackExpansionExpr *, GenericEnvironment *>
15661566
PackExpansionEnvironments;
15671567

15681568
/// The pack expansion expression for a given pack element.
@@ -1810,6 +1810,11 @@ class Solution {
18101810
return Type();
18111811
}
18121812

1813+
/// Retrieve the generic environment for the opened element of a given pack
1814+
/// expansion, or \c nullptr if no environment was recorded.
1815+
GenericEnvironment *
1816+
getPackExpansionEnvironment(PackExpansionExpr *expr) const;
1817+
18131818
/// For a given locator describing a function argument conversion, or a
18141819
/// constraint within an argument conversion, returns information about the
18151820
/// application of the argument to its parameter. If the locator is not
@@ -2407,7 +2412,7 @@ class ConstraintSystem {
24072412
llvm::SmallDenseMap<PackExpansionType *, TypeVariableType *, 4>
24082413
OpenedPackExpansionTypes;
24092414

2410-
llvm::SmallDenseMap<ConstraintLocator *, std::pair<UUID, Type>, 4>
2415+
llvm::SmallDenseMap<PackExpansionExpr *, GenericEnvironment *, 4>
24112416
PackExpansionEnvironments;
24122417

24132418
llvm::SmallDenseMap<PackElementExpr *, PackExpansionExpr *, 2>
@@ -3370,13 +3375,26 @@ class ConstraintSystem {
33703375
void recordOpenedExistentialType(ConstraintLocator *locator,
33713376
OpenedArchetypeType *opened);
33723377

3373-
/// Get the opened element generic environment for the given locator.
3374-
GenericEnvironment *getPackElementEnvironment(ConstraintLocator *locator,
3375-
CanType shapeClass);
3378+
/// Retrieve the generic environment for the opened element of a given pack
3379+
/// expansion, or \c nullptr if no environment was recorded yet.
3380+
GenericEnvironment *
3381+
getPackExpansionEnvironment(PackExpansionExpr *expr) const;
3382+
3383+
/// Create a new opened element generic environment for the given pack
3384+
/// expansion.
3385+
GenericEnvironment *
3386+
createPackExpansionEnvironment(PackExpansionExpr *expr,
3387+
CanGenericTypeParamType shapeParam);
33763388

33773389
/// Update PackExpansionEnvironments and record a change in the trail.
3378-
void recordPackExpansionEnvironment(ConstraintLocator *locator,
3379-
std::pair<UUID, Type> uuidAndShape);
3390+
void recordPackExpansionEnvironment(PackExpansionExpr *expr,
3391+
GenericEnvironment *env);
3392+
3393+
/// Undo the above change.
3394+
void removePackExpansionEnvironment(PackExpansionExpr *expr) {
3395+
bool erased = PackExpansionEnvironments.erase(expr);
3396+
ASSERT(erased);
3397+
}
33803398

33813399
/// Get the pack expansion expr for the given pack element.
33823400
PackExpansionExpr *

lib/Sema/CSApply.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,21 +3942,15 @@ namespace {
39423942
}
39433943

39443944
Expr *visitPackExpansionExpr(PackExpansionExpr *expr) {
3945-
simplifyExprType(expr);
3946-
39473945
// Set the opened pack element environment for this pack expansion.
3948-
auto expansionTy = cs.getType(expr)->castTo<PackExpansionType>();
3949-
auto *locator = cs.getConstraintLocator(expr);
3950-
auto *environment = cs.getPackElementEnvironment(locator,
3951-
expansionTy->getCountType()->getCanonicalType());
3952-
39533946
// Assert that we have an opened element environment, otherwise we'll get
39543947
// an ASTVerifier crash when pack archetypes or element archetypes appear
39553948
// inside the pack expansion expression.
3949+
auto *environment = solution.getPackExpansionEnvironment(expr);
39563950
assert(environment);
39573951
expr->setGenericEnvironment(environment);
39583952

3959-
return expr;
3953+
return simplifyExprType(expr);
39603954
}
39613955

39623956
Expr *visitPackElementExpr(PackElementExpr *expr) {

lib/Sema/CSSimplify.cpp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9629,43 +9629,57 @@ ConstraintSystem::simplifyBindTupleOfFunctionParamsConstraint(
96299629
ConstraintSystem::SolutionKind
96309630
ConstraintSystem::matchPackElementType(Type elementType, Type patternType,
96319631
ConstraintLocatorBuilder locator) {
9632-
auto *loc = getConstraintLocator(locator);
9633-
auto shapeClass = patternType->getReducedShape();
9634-
auto *elementEnv = getPackElementEnvironment(loc, shapeClass);
9635-
9636-
// Without an opened element environment, we cannot derive the
9637-
// element binding.
9638-
if (!elementEnv) {
9632+
auto tryFix = [&](llvm::function_ref<ConstraintFix *(void)> fix) {
96399633
if (!shouldAttemptFixes())
96409634
return SolutionKind::Error;
96419635

9642-
// `each` was applied to a concrete type.
9643-
if (!shapeClass->is<PackArchetypeType>()) {
9644-
if (recordFix(AllowInvalidPackElement::create(*this, patternType, loc)))
9645-
return SolutionKind::Error;
9646-
} else {
9647-
auto envShape = PackExpansionEnvironments.find(loc);
9648-
if (envShape == PackExpansionEnvironments.end()) {
9649-
return SolutionKind::Error;
9650-
}
9651-
auto *fix = SkipSameShapeRequirement::create(
9652-
*this, envShape->second.second, shapeClass,
9653-
getConstraintLocator(loc, ConstraintLocator::PackShape));
9654-
if (recordFix(fix)) {
9655-
return SolutionKind::Error;
9656-
}
9657-
}
9636+
if (recordFix(fix()))
9637+
return SolutionKind::Error;
96589638

96599639
recordAnyTypeVarAsPotentialHole(elementType);
96609640
return SolutionKind::Solved;
9641+
};
9642+
9643+
auto *loc = getConstraintLocator(locator);
9644+
ASSERT(loc->directlyAt<PackExpansionExpr>());
9645+
auto *packExpansion = castToExpr<PackExpansionExpr>(loc->getAnchor());
9646+
9647+
ASSERT(!patternType->hasTypeVariable());
9648+
auto shapeClass = patternType->getReducedShape();
9649+
9650+
// `each` was applied to a concrete type.
9651+
if (!shapeClass->is<PackArchetypeType>()) {
9652+
return tryFix([&]() {
9653+
return AllowInvalidPackElement::create(*this, patternType, loc);
9654+
});
9655+
}
9656+
9657+
auto shapeParam = CanGenericTypeParamType(cast<GenericTypeParamType>(
9658+
shapeClass->mapTypeOutOfContext()->getCanonicalType()));
9659+
9660+
auto *genericEnv = getPackExpansionEnvironment(packExpansion);
9661+
if (genericEnv) {
9662+
if (shapeParam != genericEnv->getOpenedElementShapeClass()) {
9663+
return tryFix([&]() {
9664+
auto envShape = genericEnv->mapTypeIntoContext(
9665+
genericEnv->getOpenedElementShapeClass());
9666+
if (auto *pack = dyn_cast<PackType>(envShape))
9667+
envShape = pack->unwrapSingletonPackExpansion()->getPatternType();
9668+
9669+
return SkipSameShapeRequirement::create(
9670+
*this, envShape, shapeClass,
9671+
getConstraintLocator(loc, ConstraintLocator::PackShape));
9672+
});
9673+
}
9674+
} else {
9675+
genericEnv = createPackExpansionEnvironment(packExpansion, shapeParam);
96619676
}
96629677

96639678
auto expectedElementTy =
9664-
elementEnv->mapContextualPackTypeIntoElementContext(patternType);
9679+
genericEnv->mapContextualPackTypeIntoElementContext(patternType);
96659680
assert(!expectedElementTy->is<PackType>());
96669681

9667-
addConstraint(ConstraintKind::Equal, elementType, expectedElementTy,
9668-
locator);
9682+
addConstraint(ConstraintKind::Equal, elementType, expectedElementTy, locator);
96699683
return SolutionKind::Solved;
96709684
}
96719685

lib/Sema/CSTrail.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ SolverTrail::Change SolverTrail::Change::RecordedPackElementExpansion(
199199
return result;
200200
}
201201

202+
SolverTrail::Change
203+
SolverTrail::Change::RecordedPackExpansionEnvironment(PackExpansionExpr *expr) {
204+
Change result;
205+
result.Kind = ChangeKind::RecordedPackExpansionEnvironment;
206+
result.TheExpansionExpr = expr;
207+
return result;
208+
}
209+
202210
SolverTrail::Change
203211
SolverTrail::Change::RecordedNodeType(ASTNode node, Type oldType) {
204212
Change result;
@@ -430,6 +438,10 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const {
430438
cs.removePackElementExpansion(TheElement);
431439
break;
432440

441+
case ChangeKind::RecordedPackExpansionEnvironment:
442+
cs.removePackExpansionEnvironment(TheExpansionExpr);
443+
break;
444+
433445
case ChangeKind::RecordedNodeType:
434446
cs.restoreType(Node.Node, Node.OldType);
435447
break;
@@ -702,6 +714,12 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out,
702714
out << ")\n";
703715
break;
704716

717+
case ChangeKind::RecordedPackExpansionEnvironment:
718+
out << "(RecordedPackExpansionEnvironment ";
719+
dumpAnchor(TheExpansionExpr, &SM, out);
720+
out << ")\n";
721+
break;
722+
705723
case ChangeKind::RecordedNodeType:
706724
out << "(RecordedNodeType at ";
707725
Node.Node.getStartLoc().print(out, cs.getASTContext().SourceMgr);

lib/Sema/ConstraintSystem.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -877,44 +877,35 @@ void ConstraintSystem::recordOpenedExistentialType(
877877
}
878878

879879
GenericEnvironment *
880-
ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator,
881-
CanType shapeClass) {
882-
assert(locator->directlyAt<PackExpansionExpr>());
883-
884-
std::pair<UUID, Type> uuidAndShape;
885-
auto result = PackExpansionEnvironments.find(locator);
886-
if (result == PackExpansionEnvironments.end()) {
887-
uuidAndShape = std::make_pair(UUID::fromTime(), shapeClass);
888-
recordPackExpansionEnvironment(locator, uuidAndShape);
889-
} else {
890-
uuidAndShape = result->second;
891-
}
892-
893-
if (!shapeClass->is<PackArchetypeType>() ||
894-
!shapeClass->isEqual(uuidAndShape.second))
880+
ConstraintSystem::getPackExpansionEnvironment(PackExpansionExpr *expr) const {
881+
auto result = PackExpansionEnvironments.find(expr);
882+
if (result == PackExpansionEnvironments.end())
895883
return nullptr;
896884

897-
auto shapeParam = cast<GenericTypeParamType>(
898-
shapeClass->mapTypeOutOfContext()->getCanonicalType());
885+
return result->second;
886+
}
899887

900-
auto &ctx = getASTContext();
888+
GenericEnvironment *ConstraintSystem::createPackExpansionEnvironment(
889+
PackExpansionExpr *expr, CanGenericTypeParamType shapeParam) {
901890
auto *contextEnv = PackElementGenericEnvironments.empty()
902891
? DC->getGenericEnvironmentOfContext()
903892
: PackElementGenericEnvironments.back();
904-
auto elementSig = ctx.getOpenedElementSignature(
893+
auto elementSig = getASTContext().getOpenedElementSignature(
905894
contextEnv->getGenericSignature().getCanonicalSignature(), shapeParam);
906895
auto contextSubs = contextEnv->getForwardingSubstitutionMap();
907-
return GenericEnvironment::forOpenedElement(elementSig, uuidAndShape.first,
908-
shapeParam, contextSubs);
896+
auto *env = GenericEnvironment::forOpenedElement(elementSig, UUID::fromTime(),
897+
shapeParam, contextSubs);
898+
recordPackExpansionEnvironment(expr, env);
899+
return env;
909900
}
910901

911-
void ConstraintSystem::recordPackExpansionEnvironment(
912-
ConstraintLocator *locator, std::pair<UUID, Type> uuidAndShape) {
913-
bool inserted = PackExpansionEnvironments.insert({locator, uuidAndShape}).second;
902+
void ConstraintSystem::recordPackExpansionEnvironment(PackExpansionExpr *expr,
903+
GenericEnvironment *env) {
904+
bool inserted = PackExpansionEnvironments.insert({expr, env}).second;
914905
ASSERT(inserted);
915906

916907
if (solverState)
917-
recordChange(SolverTrail::Change::RecordedPackExpansionEnvironment(locator));
908+
recordChange(SolverTrail::Change::RecordedPackExpansionEnvironment(expr));
918909
}
919910

920911
PackExpansionExpr *
@@ -4050,6 +4041,15 @@ ASTNode ConstraintSystem::includingParentApply(ASTNode node) {
40504041
return node;
40514042
}
40524043

4044+
GenericEnvironment *
4045+
Solution::getPackExpansionEnvironment(PackExpansionExpr *expr) const {
4046+
auto iter = PackExpansionEnvironments.find(expr);
4047+
if (iter == PackExpansionEnvironments.end())
4048+
return nullptr;
4049+
4050+
return iter->second;
4051+
}
4052+
40534053
std::optional<FunctionArgApplyInfo>
40544054
Solution::getFunctionArgApplyInfo(ConstraintLocator *locator) const {
40554055
// It's only valid to use `&` in argument positions, but we need

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,11 +1632,11 @@ void ConstraintSystem::print(raw_ostream &out) const {
16321632

16331633
if (!PackExpansionEnvironments.empty()) {
16341634
out.indent(indent) << "Pack Expansion Environments:\n";
1635-
for (const auto &env : PackExpansionEnvironments) {
1635+
for (const auto &[packExpansion, env] : PackExpansionEnvironments) {
16361636
out.indent(indent + 2);
1637-
env.first->dump(&getASTContext().SourceMgr, out);
1638-
out << " = (" << env.second.first << ", "
1639-
<< env.second.second->getString(PO) << ")" << '\n';
1637+
dumpAnchor(packExpansion, &getASTContext().SourceMgr, out);
1638+
out << " = (" << env->getOpenedElementShapeClass() << ", "
1639+
<< env->getOpenedElementUUID() << ")" << '\n';
16401640
}
16411641
}
16421642

0 commit comments

Comments
 (0)