Skip to content

add tracking of pack environments for pack elements to Constraint System #67164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,10 @@ class Solution {
llvm::DenseMap<ConstraintLocator *, std::pair<UUID, Type>>
PackExpansionEnvironments;

/// The pack expansion environment that can open a given pack element.
llvm::SmallMapVector<PackElementExpr *, PackExpansionExpr *, 2>
PackEnvironments;

/// The locators of \c Defaultable constraints whose defaults were used.
llvm::SmallPtrSet<ConstraintLocator *, 2> DefaultedConstraints;

Expand Down Expand Up @@ -2251,6 +2255,9 @@ class ConstraintSystem {
llvm::SmallMapVector<ConstraintLocator *, std::pair<UUID, Type>, 4>
PackExpansionEnvironments;

llvm::SmallMapVector<PackElementExpr *, PackExpansionExpr *, 2>
PackEnvironments;

/// The set of functions that have been transformed by a result builder.
llvm::MapVector<AnyFunctionRef, AppliedBuilderTransform>
resultBuilderTransformed;
Expand Down Expand Up @@ -2737,6 +2744,9 @@ class ConstraintSystem {
/// The length of \c PackExpansionEnvironments.
unsigned numPackExpansionEnvironments;

/// The length of \c PackEnvironments.
unsigned numPackEnvironments;

/// The length of \c DefaultedConstraints.
unsigned numDefaultedConstraints;

Expand Down Expand Up @@ -3232,6 +3242,13 @@ class ConstraintSystem {
GenericEnvironment *getPackElementEnvironment(ConstraintLocator *locator,
CanType shapeClass);

/// Get the opened element generic environment for the given pack element.
PackExpansionExpr *getPackEnvironment(PackElementExpr *packElement) const;

/// Associate an opened element generic environment to a pack element.
void addPackEnvironment(PackElementExpr *packElement,
PackExpansionExpr *packExpansion);

/// Retrieve the constraint locator for the given anchor and
/// path, uniqued and automatically infer the summary flags
ConstraintLocator *
Expand Down
93 changes: 66 additions & 27 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,20 +1089,20 @@ namespace {
return outputTy;
}

Type openPackElement(Type packType, ConstraintLocator *locator) {
Type openPackElement(Type packType, ConstraintLocator *locator,
PackExpansionExpr *packElementEnvironment) {
// If 'each t' is written outside of a pack expansion expression, allow the
// type to bind to a hole. The invalid pack reference will be diagnosed when
// attempting to bind the type variable for the underlying pack reference to
// a pack type without TVO_CanBindToPack.
if (PackElementEnvironments.empty()) {
if (!packElementEnvironment) {
return CS.createTypeVariable(locator,
TVO_CanBindToHole | TVO_CanBindToNoEscape);
}

// The type of a PackElementExpr is the opened pack element archetype
// of the pack reference.
OpenPackElementType openPackElement(CS, locator,
PackElementEnvironments.back());
OpenPackElementType openPackElement(CS, locator, packElementEnvironment);
return openPackElement(packType, /*packRepr*/ nullptr);
}

Expand All @@ -1124,6 +1124,26 @@ namespace {

void addPackElementEnvironment(PackExpansionExpr *expr) {
PackElementEnvironments.push_back(expr);

SmallVector<ASTNode, 2> expandedPacks;
collectExpandedPacks(expr, expandedPacks);
for (auto pack : expandedPacks) {
if (auto *elementExpr = getAsExpr<PackElementExpr>(pack)) {
CS.addPackEnvironment(elementExpr, expr);
}
}

auto *patternLoc = CS.getConstraintLocator(
expr, ConstraintLocator::PackExpansionPattern);
auto patternType = CS.createTypeVariable(
patternLoc,
TVO_CanBindToPack | TVO_CanBindToNoEscape | TVO_CanBindToHole);
auto *shapeLoc =
CS.getConstraintLocator(expr, ConstraintLocator::PackShape);
auto *shapeTypeVar = CS.createTypeVariable(
shapeLoc, TVO_CanBindToPack | TVO_CanBindToHole);
auto expansionType = PackExpansionType::get(patternType, shapeTypeVar);
CS.setType(expr, expansionType);
}

virtual Type visitErrorExpr(ErrorExpr *E) {
Expand Down Expand Up @@ -1384,6 +1404,17 @@ namespace {
return BoundGenericStructType::get(regexDecl, Type(), {matchType});
}

PackExpansionExpr *getParentPackExpansionExpr(Expr *E) const {
auto *current = E;
while (auto *parent = CS.getParentExpr(current)) {
if (auto *expansion = dyn_cast<PackExpansionExpr>(parent)) {
return expansion;
}
current = parent;
}
return nullptr;
}

Type visitDeclRefExpr(DeclRefExpr *E) {
auto locator = CS.getConstraintLocator(E);

Expand Down Expand Up @@ -1426,13 +1457,15 @@ namespace {

// value packs cannot be referenced without `each` immediately
// preceding them.
if (auto *expansion = knownType->getAs<PackExpansionType>()) {
if (!PackElementEnvironments.empty() &&
if (auto *expansionType = knownType->getAs<PackExpansionType>()) {
if (auto *parentExpansionExpr = getParentPackExpansionExpr(E);
parentExpansionExpr &&
!isExpr<PackElementExpr>(CS.getParentExpr(E))) {
auto packType = expansion->getPatternType();
auto packType = expansionType->getPatternType();
(void)CS.recordFix(
IgnoreMissingEachKeyword::create(CS, packType, locator));
auto eltType = openPackElement(packType, locator);
auto eltType =
openPackElement(packType, locator, parentExpansionExpr);
CS.setType(E, eltType);
return eltType;
}
Expand Down Expand Up @@ -3033,21 +3066,11 @@ namespace {
assert(PackElementEnvironments.back() == expr);
PackElementEnvironments.pop_back();

auto *patternLoc =
CS.getConstraintLocator(expr, ConstraintLocator::PackExpansionPattern);
auto patternTy = CS.createTypeVariable(patternLoc,
TVO_CanBindToPack |
TVO_CanBindToNoEscape |
TVO_CanBindToHole);
auto expansionType = CS.getType(expr)->castTo<PackExpansionType>();
auto elementResultType = CS.getType(expr->getPatternExpr());
CS.addConstraint(ConstraintKind::PackElementOf, elementResultType,
patternTy, CS.getConstraintLocator(expr));

auto *shapeLoc =
CS.getConstraintLocator(expr, ConstraintLocator::PackShape);
auto *shapeTypeVar = CS.createTypeVariable(shapeLoc,
TVO_CanBindToPack |
TVO_CanBindToHole);
expansionType->getPatternType(),
CS.getConstraintLocator(expr));

// Generate ShapeOf constraints between all packs expanded by this
// pack expansion expression through the shape type variable.
Expand All @@ -3061,9 +3084,14 @@ namespace {

for (auto pack : expandedPacks) {
Type packType;
if (auto *elementExpr = getAsExpr<PackElementExpr>(pack)) {
packType = CS.getType(elementExpr->getPackRefExpr());
} else if (auto *elementType = getAsTypeRepr<PackElementTypeRepr>(pack)) {
/// Skipping over pack elements because the relationship to its
/// environment is now established during \c addPackElementEnvironment
/// upon visiting its pack expansion and the Shape constraint added
/// upon visiting the pack element.
if (isExpr<PackElementExpr>(pack)) {
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Could you please leave a comment here explaining why we skip over pack elements and where the connection between this pack expansion and pack element is established now to make it easier in the future to figure out what is going on here?

} else if (auto *elementType =
getAsTypeRepr<PackElementTypeRepr>(pack)) {
// OpenPackElementType sets types for 'each T' type reprs in
// expressions. Some invalid code won't make it there, and
// the constraint system won't have recorded a type.
Expand All @@ -3076,11 +3104,11 @@ namespace {
}

CS.addConstraint(
ConstraintKind::ShapeOf, shapeTypeVar, packType,
ConstraintKind::ShapeOf, expansionType->getCountType(), packType,
CS.getConstraintLocator(expr, ConstraintLocator::PackShape));
}

return PackExpansionType::get(patternTy, shapeTypeVar);
return expansionType;
}

Type visitPackElementExpr(PackElementExpr *expr) {
Expand All @@ -3094,7 +3122,18 @@ namespace {
CS.setType(expr->getPackRefExpr(), packType);
}

return openPackElement(packType, CS.getConstraintLocator(expr));
auto *packEnvironment = CS.getPackEnvironment(expr);
if (packEnvironment) {
auto expansionType =
CS.getType(packEnvironment)->castTo<PackExpansionType>();
CS.addConstraint(ConstraintKind::ShapeOf, expansionType->getCountType(),
packType,
CS.getConstraintLocator(packEnvironment,
ConstraintLocator::PackShape));
}

return openPackElement(packType, CS.getConstraintLocator(expr),
packEnvironment);
}

Type visitMaterializePackExpr(MaterializePackExpr *expr) {
Expand Down
11 changes: 11 additions & 0 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ Solution ConstraintSystem::finalize() {
solution.PackExpansionEnvironments.insert(env);
}

solution.PackEnvironments = PackEnvironments;

return solution;
}

Expand Down Expand Up @@ -290,6 +292,11 @@ void ConstraintSystem::applySolution(const Solution &solution) {
PackExpansionEnvironments.insert(expansion);
}

// Register the solutions's pack environments.
for (auto &packEnvironment : solution.PackEnvironments) {
PackEnvironments.insert(packEnvironment);
}

// Register the defaulted type variables.
DefaultedConstraints.insert(solution.DefaultedConstraints.begin(),
solution.DefaultedConstraints.end());
Expand Down Expand Up @@ -613,6 +620,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numOpenedExistentialTypes = cs.OpenedExistentialTypes.size();
numOpenedPackExpansionTypes = cs.OpenedPackExpansionTypes.size();
numPackExpansionEnvironments = cs.PackExpansionEnvironments.size();
numPackEnvironments = cs.PackEnvironments.size();
numDefaultedConstraints = cs.DefaultedConstraints.size();
numAddedNodeTypes = cs.addedNodeTypes.size();
numAddedKeyPathComponentTypes = cs.addedKeyPathComponentTypes.size();
Expand Down Expand Up @@ -697,6 +705,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
// Remove any pack expansion environments.
truncate(cs.PackExpansionEnvironments, numPackExpansionEnvironments);

// Remove any pack environments.
truncate(cs.PackEnvironments, numPackEnvironments);

// Remove any defaulted type variables.
truncate(cs.DefaultedConstraints, numDefaultedConstraints);

Expand Down
15 changes: 15 additions & 0 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,21 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator,
shapeParam, contextSubs);
}

PackExpansionExpr *
ConstraintSystem::getPackEnvironment(PackElementExpr *packElement) const {
const auto match = PackEnvironments.find(packElement);
return (match == PackEnvironments.end()) ? nullptr : match->second;
}

void ConstraintSystem::addPackEnvironment(PackElementExpr *packElement,
PackExpansionExpr *packExpansion) {
assert(packElement);
assert(packExpansion);
[[maybe_unused]] const auto inserted =
PackEnvironments.insert({packElement, packExpansion}).second;
assert(inserted && "Mapping already defined?");
}

/// Extend the given depth map by adding depths for all of the subexpressions
/// of the given expression.
static void extendDepthMap(
Expand Down
15 changes: 14 additions & 1 deletion test/Constraints/pack-expansion-expressions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func typeReprPacks<each T: ExpressibleByIntegerLiteral>(_ t: repeat each T) {
func sameShapeDiagnostics<each T, each U>(t: repeat each T, u: repeat each U) {
_ = (repeat (each t, each u)) // expected-error {{pack expansion requires that 'each T' and 'each U' have the same shape}}
_ = (repeat Array<(each T, each U)>()) // expected-error {{pack expansion requires that 'each T' and 'each U' have the same shape}}
_ = (repeat (Array<each T>(), each u)) // expected-error {{pack expansion requires that 'each T' and 'each U' have the same shape}}
_ = (repeat (Array<each T>(), each u)) // expected-error {{pack expansion requires that 'each U' and 'each T' have the same shape}}
}

func returnPackExpansionType<each T>(_ t: repeat each T) -> repeat each T { // expected-error {{pack expansion 'repeat each T' can only appear in a function parameter list, tuple element, or generic argument list}}
Expand Down Expand Up @@ -604,3 +604,16 @@ func test_that_expansions_are_bound_early() {
})
}
}

do {
func test<T>(x: T) {}

// rdar://110711746 to make this valid
func caller1<each T>(x: repeat each T) {
_ = (repeat { test(x: each x) }()) // expected-error {{pack reference 'each T' can only appear in pack expansion}}
}

func caller2<each T>(x: repeat each T) {
_ = { (repeat test(x: each x)) }()
}
}