Skip to content

Fix pack expansion matching #61610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -5197,6 +5197,12 @@ class ConstraintSystem {
ConstraintKind kind, TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

TypeMatchResult
matchPackExpansionTypes(PackExpansionType *expansion1,
PackExpansionType *expansion2,
ConstraintKind kind, TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Subroutine of \c matchTypes(), which matches up two tuple types.
///
/// \returns the result of performing the tuple-to-tuple conversion.
Expand Down
56 changes: 30 additions & 26 deletions lib/AST/PackExpansionMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ bool TuplePackMatcher::match() {
// A pack expansion type on the left hand side absorbs all elements
// from the right hand side up to the next mismatched label.
auto lhsElt = lhsElts.front();
if (lhsElt.getType()->is<PackExpansionType>()) {
if (auto *lhsExpansionType = lhsElt.getType()->getAs<PackExpansionType>()) {
lhsElts = lhsElts.slice(1);

assert(lhsElts.empty() || lhsElts.front().hasName() &&
"Tuple element with pack expansion type cannot be followed "
"by an unlabeled element");

auto *rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
pairs.emplace_back(lhsElt.getType(), rhs, idx++);
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, idx++);
continue;
}

Expand All @@ -82,15 +82,15 @@ bool TuplePackMatcher::match() {
// A pack expansion type on the right hand side absorbs all elements
// from the left hand side up to the next mismatched label.
auto rhsElt = rhsElts.front();
if (rhsElt.getType()->is<PackExpansionType>()) {
if (auto *rhsExpansionType = rhsElt.getType()->getAs<PackExpansionType>()) {
rhsElts = rhsElts.slice(1);

assert(rhsElts.empty() || rhsElts.front().hasName() &&
"Tuple element with pack expansion type cannot be followed "
"by an unlabeled element");

auto *lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
pairs.emplace_back(lhs, rhsElt.getType(), idx++);
pairs.emplace_back(lhs, rhsExpansionType->getPatternType(), idx++);
continue;
}

Expand Down Expand Up @@ -169,34 +169,38 @@ bool ParamPackMatcher::match() {

// If the left hand side is a single pack expansion type, bind it
// to what remains of the right hand side.
if (lhsParams.size() == 1 &&
lhsParams[0].getPlainType()->is<PackExpansionType>()) {
SmallVector<Type, 2> rhsTypes;
for (auto rhsParam : rhsParams) {
// FIXME: Check rhs flags
rhsTypes.push_back(rhsParam.getPlainType());
}
auto rhs = PackType::get(ctx, rhsTypes);
if (lhsParams.size() == 1) {
auto lhsType = lhsParams[0].getPlainType();
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
SmallVector<Type, 2> rhsTypes;
for (auto rhsParam : rhsParams) {
// FIXME: Check rhs flags
rhsTypes.push_back(rhsParam.getPlainType());
}
auto rhs = PackType::get(ctx, rhsTypes);

// FIXME: Check lhs flags
pairs.emplace_back(lhsParams[0].getPlainType(), rhs, prefixLength);
return false;
// FIXME: Check lhs flags
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
return false;
}
}

// If the right hand side is a single pack expansion type, bind it
// to what remains of the left hand side.
if (rhsParams.size() == 1 &&
rhsParams[0].getPlainType()->is<PackExpansionType>()) {
SmallVector<Type, 2> lhsTypes;
for (auto lhsParam : lhsParams) {
// FIXME: Check lhs flags
lhsTypes.push_back(lhsParam.getPlainType());
}
auto lhs = PackType::get(ctx, lhsTypes);
if (rhsParams.size() == 1) {
auto rhsType = rhsParams[0].getPlainType();
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
SmallVector<Type, 2> lhsTypes;
for (auto lhsParam : lhsParams) {
// FIXME: Check lhs flags
lhsTypes.push_back(lhsParam.getPlainType());
}
auto lhs = PackType::get(ctx, lhsTypes);

// FIXME: Check rhs flags
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
return false;
// FIXME: Check rhs flags
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
return false;
}
}

// Otherwise, all remaining possibilities are invalid:
Expand Down
42 changes: 24 additions & 18 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,23 @@ ConstraintSystem::matchPackTypes(PackType *pack1, PackType *pack2,
return getTypeMatchSuccess();
}

ConstraintSystem::TypeMatchResult
ConstraintSystem::matchPackExpansionTypes(PackExpansionType *expansion1,
PackExpansionType *expansion2,
ConstraintKind kind, TypeMatchOptions flags,
ConstraintLocatorBuilder locator) {
// FIXME: Should we downgrade kind to Bind or something here?
auto result = matchTypes(expansion1->getCountType(),
expansion2->getCountType(),
kind, flags, locator);
if (result.isFailure())
return result;

return matchTypes(expansion1->getPatternType(),
expansion2->getPatternType(),
kind, flags, locator);
}

/// Check where a representation is a subtype of another.
///
/// The subtype relationship is defined as:
Expand Down Expand Up @@ -6636,10 +6653,13 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
kind, subflags, packLoc);
}
case TypeKind::PackExpansion: {
// FIXME: we need to match the count types as well
return matchTypes(cast<PackExpansionType>(desugar1)->getPatternType(),
cast<PackExpansionType>(desugar2)->getPatternType(),
kind, subflags, locator);
// FIXME: Need a new locator element

auto expansion1 = cast<PackExpansionType>(desugar1);
auto expansion2 = cast<PackExpansionType>(desugar2);

return matchPackExpansionTypes(expansion1, expansion2, kind, subflags,
locator);
}
}
}
Expand Down Expand Up @@ -7034,20 +7054,6 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
}
}

if (isa<PackExpansionType>(desugar1) && isa<PackType>(desugar2)) {
auto *packExpansionType = cast<PackExpansionType>(desugar1);
auto *packType = cast<PackType>(desugar2);

if (packExpansionType->getPatternType()->is<TypeVariableType>())
return matchTypes(packExpansionType->getPatternType(), packType, kind, subflags, locator);
} else if (isa<PackType>(desugar1) && isa<PackExpansionType>(desugar2)) {
auto *packType = cast<PackType>(desugar1);
auto *packExpansionType = cast<PackExpansionType>(desugar2);

if (packExpansionType->getPatternType()->is<TypeVariableType>())
return matchTypes(packType, packExpansionType->getPatternType(), kind, subflags, locator);
}

// Attempt fixes iff it's allowed, both types are concrete and
// we are not in the middle of attempting one already.
if (shouldAttemptFixes() && !flags.contains(TMF_ApplyingFix)) {
Expand Down