Skip to content

Commit 882e1e7

Browse files
committed
[Requirement Machine] Implement same-element requirements.
1 parent b330ac5 commit 882e1e7

22 files changed

+160
-32
lines changed

lib/AST/ASTPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,9 @@ void PrintAST::printRequirement(const Requirement &req) {
18561856
SmallVector<Type, 2> rootParameterPacks;
18571857
getTransformedType(req.getFirstType())
18581858
->getTypeParameterPacks(rootParameterPacks);
1859+
if (req.getKind() != RequirementKind::Layout)
1860+
getTransformedType(req.getSecondType())
1861+
->getTypeParameterPacks(rootParameterPacks);
18591862
bool isPackRequirement = !rootParameterPacks.empty();
18601863

18611864
switch (req.getKind()) {

lib/AST/GenericEnvironment.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,16 @@ struct FindElementArchetypeForOpenedPackParam {
180180
: findElementParam(env, openedPacks), getElementArchetype(env) {}
181181

182182

183-
ElementArchetypeType *operator()(Type interfaceType) {
183+
Type operator()(Type interfaceType) {
184184
assert(interfaceType->isTypeParameter());
185185
if (auto member = interfaceType->getAs<DependentMemberType>()) {
186-
auto baseArchetype = (*this)(member->getBase());
186+
auto baseArchetype = (*this)(member->getBase())
187+
->castTo<ElementArchetypeType>();
187188
return baseArchetype->getNestedType(member->getAssocType())
188189
->castTo<ElementArchetypeType>();
189190
}
190191
assert(interfaceType->is<GenericTypeParamType>());
191-
return getElementArchetype(findElementParam(interfaceType))
192-
->castTo<ElementArchetypeType>();
192+
return getElementArchetype(findElementParam(interfaceType));
193193
}
194194
};
195195

lib/AST/GenericSignature.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,10 @@ int swift::compareDependentTypes(Type type1, Type type2) {
794794
// Fast-path check for equality.
795795
if (type1->isEqual(type2)) return 0;
796796

797+
// Packs are always ordered after scalar type parameters.
798+
if (type1->isParameterPack() != type2->isParameterPack())
799+
return type2->isParameterPack() ? -1 : +1;
800+
797801
// Ordering is as follows:
798802
// - Generic params
799803
auto gp1 = type1->getAs<GenericTypeParamType>();

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ RequirementMachine::getLongestValidPrefix(const MutableTerm &term) const {
278278
case Symbol::Kind::ConcreteType:
279279
case Symbol::Kind::ConcreteConformance:
280280
case Symbol::Kind::Shape:
281+
case Symbol::Kind::PackElement:
281282
llvm::errs() <<"Invalid symbol in a type term: " << term << "\n";
282283
abort();
283284
}
@@ -307,6 +308,20 @@ bool RequirementMachine::isReducedType(Type type) const {
307308
if (!component->hasTypeParameter())
308309
return Action::SkipChildren;
309310

311+
if (auto *expansion = component->getAs<PackExpansionType>()) {
312+
auto pattern = expansion->getPatternType();
313+
auto shape = expansion->getCountType();
314+
if (!Self.isReducedType(pattern))
315+
return Action::Stop;
316+
317+
auto reducedShape =
318+
Self.getReducedShape(shape, Self.getGenericParams());
319+
if (reducedShape->getCanonicalType() != CanType(shape))
320+
return Action::Stop;
321+
322+
return Action::SkipChildren;
323+
}
324+
310325
if (!component->isTypeParameter())
311326
return Action::Continue;
312327

@@ -788,6 +803,7 @@ void RequirementMachine::verify(const MutableTerm &term) const {
788803
switch (symbol.getKind()) {
789804
case Symbol::Kind::Protocol:
790805
case Symbol::Kind::GenericParam:
806+
case Symbol::Kind::PackElement:
791807
erased.add(symbol);
792808
continue;
793809

@@ -827,6 +843,7 @@ void RequirementMachine::verify(const MutableTerm &term) const {
827843
case Symbol::Kind::Superclass:
828844
case Symbol::Kind::ConcreteType:
829845
case Symbol::Kind::ConcreteConformance:
846+
case Symbol::Kind::PackElement:
830847
llvm::errs() << "Bad interior symbol " << symbol << " in " << term << "\n";
831848
abort();
832849
break;

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,8 @@ RewriteSystem::getMinimizedGenericSignatureRules() const {
743743
continue;
744744
}
745745

746-
if (rule.getLHS()[0].getKind() != Symbol::Kind::GenericParam)
746+
if (rule.getLHS()[0].getKind() != Symbol::Kind::PackElement &&
747+
rule.getLHS()[0].getKind() != Symbol::Kind::GenericParam)
747748
continue;
748749

749750
rules.push_back(ruleID);

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end,
298298
// member type rooted at Self; handle the associated type below.
299299
break;
300300

301+
case Symbol::Kind::PackElement:
302+
continue;
303+
301304
case Symbol::Kind::Name:
302305
case Symbol::Kind::Layout:
303306
case Symbol::Kind::Superclass:

lib/AST/RequirementMachine/MinimalConformances.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ static const ProtocolDecl *getParentConformanceForTerm(Term lhs) {
286286
case Symbol::Kind::ConcreteType:
287287
case Symbol::Kind::ConcreteConformance:
288288
case Symbol::Kind::Shape:
289+
case Symbol::Kind::PackElement:
289290
break;
290291
}
291292

@@ -552,6 +553,7 @@ void RewriteSystem::computeCandidateConformancePaths(
552553
//
553554
// where Y is the simplified form of X.W.
554555
} else if (rhs.isAnyConformanceRule() &&
556+
!lhs.isSameElementRule() &&
555557
(unsigned)(lhs.getLHS().end() - from) < rhs.getLHS().size()) {
556558
if (Debug.contains(DebugFlags::MinimalConformancesDetail)) {
557559
llvm::dbgs() << "Case 2: same-type suffix\n";

lib/AST/RequirementMachine/PropertyUnification.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ void PropertyMap::addProperty(
646646
case Symbol::Kind::GenericParam:
647647
case Symbol::Kind::AssociatedType:
648648
case Symbol::Kind::Shape:
649+
case Symbol::Kind::PackElement:
649650
break;
650651
}
651652

lib/AST/RequirementMachine/RequirementBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
248248
case Symbol::Kind::AssociatedType:
249249
case Symbol::Kind::GenericParam:
250250
case Symbol::Kind::Shape:
251+
case Symbol::Kind::PackElement:
251252
break;
252253
}
253254

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,6 @@ static void desugarSameTypeRequirement(Requirement req, SourceLoc loc,
218218
break;
219219
}
220220

221-
// If one side is a parameter pack, this is a same-element requirement, which
222-
// is not yet supported.
223-
if (firstType->isParameterPack() != secondType->isParameterPack()) {
224-
errors.push_back(RequirementError::forSameElement(
225-
{kind, sugaredFirstType, secondType}, loc));
226-
recordedErrors = true;
227-
return true;
228-
}
229-
230221
if (firstType->isTypeParameter() && secondType->isTypeParameter()) {
231222
result.emplace_back(kind, sugaredFirstType, secondType);
232223
return true;

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ void RequirementMachine::dump(llvm::raw_ostream &out) const {
560560
for (auto paramTy : Params) {
561561
out << " " << Type(paramTy);
562562
if (paramTy->isParameterPack())
563-
out << "";
563+
out << " " << paramTy;
564564
}
565565
out << " >";
566566
}

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -963,13 +963,13 @@ InferredGenericSignatureRequest::evaluate(
963963
if (reduced->hasError() || reduced->isEqual(genericParam))
964964
continue;
965965

966-
if (reduced->isTypeParameter()) {
967-
// If one side is a parameter pack and the other is not, this is a
968-
// same-element requirement that cannot be expressed with only one
969-
// type parameter.
970-
if (genericParam->isParameterPack() != reduced->isParameterPack())
971-
continue;
966+
// If one side is a parameter pack and the other is not, this is a
967+
// same-element requirement that cannot be expressed with only one
968+
// type parameter.
969+
if (genericParam->isParameterPack() != reduced->isParameterPack())
970+
continue;
972971

972+
if (reduced->isTypeParameter()) {
973973
ctx.Diags.diagnose(loc, diag::requires_generic_params_made_equal,
974974
genericParam, result->getSugaredType(reduced))
975975
.warnUntilSwiftVersion(6);

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class RewriteContext final {
4848
/// The singleton storage for shape symbols.
4949
Symbol::Storage *TheShapeSymbol;
5050

51+
/// The singleton storage for pack element symbols.
52+
Symbol::Storage *ThePackElementSymbol;
53+
5154
/// Folding set for uniquing terms.
5255
llvm::FoldingSet<Term::Storage> Terms;
5356

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,8 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
569569
}
570570

571571
if (index != 0) {
572-
ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam);
572+
ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam ||
573+
lhs[index - 1].getKind() == Symbol::Kind::PackElement);
573574
}
574575

575576
if (!rule.isLHSSimplified() &&

lib/AST/RequirementMachine/Rule.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ const ProtocolDecl *Rule::isAnyConformanceRule() const {
7373
case Symbol::Kind::AssociatedType:
7474
case Symbol::Kind::GenericParam:
7575
case Symbol::Kind::Shape:
76+
case Symbol::Kind::PackElement:
7677
break;
7778
}
7879

@@ -146,6 +147,14 @@ bool Rule::isCircularConformanceRule() const {
146147
return true;
147148
}
148149

150+
/// Returns \c true if this rule is prefixed with the \c [element] symbol.
151+
bool Rule::isSameElementRule() const {
152+
if (LHS.size() == 0)
153+
return false;
154+
155+
return LHS[0].getKind() == Symbol::Kind::PackElement;
156+
}
157+
149158
/// A protocol typealias rule takes one of the following two forms,
150159
/// where T is a name symbol:
151160
///

lib/AST/RequirementMachine/Rule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ class Rule final {
137137

138138
bool isCircularConformanceRule() const;
139139

140+
bool isSameElementRule() const;
141+
140142
/// See above for an explanation of these predicates.
141143
bool isPermanent() const {
142144
return Permanent;

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,21 @@ void RuleBuilder::addRequirement(const Requirement &req,
390390
otherType, *substitutions)
391391
: Context.getMutableTermForType(
392392
otherType, proto));
393+
394+
if (subjectType->isParameterPack() != otherType->isParameterPack()) {
395+
// This is a same-element requirement.
396+
auto elementSymbol = Symbol::forPackElement(Context);
397+
llvm::SmallVector<Symbol, 3> symbols{elementSymbol};
398+
399+
if (subjectType->isParameterPack()) {
400+
symbols.append(subjectTerm.begin(), subjectTerm.end());
401+
subjectTerm = MutableTerm(std::move(symbols));
402+
} else {
403+
symbols.append(constraintTerm.begin(), constraintTerm.end());
404+
constraintTerm = MutableTerm(std::move(symbols));
405+
}
406+
}
407+
393408
break;
394409
}
395410
}

lib/AST/RequirementMachine/Symbol.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ struct Symbol::Storage final
8080
Kind = Kind::Shape;
8181
}
8282

83+
/// A dummy type for overload resolution of the
84+
/// 'pack element' constructor for Storage.
85+
struct ForPackElement {};
86+
87+
explicit Storage(ForPackElement shape) {
88+
Kind = Kind::PackElement;
89+
}
90+
8391
Storage(const ProtocolDecl *proto, Identifier name) {
8492
Kind = Symbol::Kind::AssociatedType;
8593
Proto = proto;
@@ -316,6 +324,16 @@ Symbol Symbol::forShape(RewriteContext &ctx) {
316324
return (ctx.TheShapeSymbol = symbol);
317325
}
318326

327+
Symbol Symbol::forPackElement(RewriteContext &ctx) {
328+
if (auto *symbol = ctx.ThePackElementSymbol)
329+
return symbol;
330+
331+
unsigned size = Storage::totalSizeToAlloc<unsigned, Term>(0, 0);
332+
void *mem = ctx.Allocator.Allocate(size, alignof(Storage));
333+
auto *symbol = new (mem) Storage(Storage::ForPackElement());
334+
return (ctx.ThePackElementSymbol = symbol);
335+
}
336+
319337
/// Creates a layout symbol, representing a layout constraint.
320338
Symbol Symbol::forLayout(LayoutConstraint layout,
321339
RewriteContext &ctx) {
@@ -455,6 +473,7 @@ const ProtocolDecl *Symbol::getRootProtocol() const {
455473
return getProtocol();
456474

457475
case Symbol::Kind::GenericParam:
476+
case Symbol::Kind::PackElement:
458477
return nullptr;
459478

460479
case Symbol::Kind::Name:
@@ -479,6 +498,8 @@ const ProtocolDecl *Symbol::getRootProtocol() const {
479498
/// - AssociatedType
480499
/// - GenericParam
481500
/// - Name
501+
/// - Shape
502+
/// - PackElement
482503
/// - Layout
483504
/// - Superclass
484505
/// - ConcreteType
@@ -535,6 +556,7 @@ llvm::Optional<int> Symbol::compare(Symbol other, RewriteContext &ctx) const {
535556
}
536557

537558
case Kind::Shape:
559+
case Kind::PackElement:
538560
case Kind::GenericParam: {
539561
auto *param = getGenericParam();
540562
auto *otherParam = other.getGenericParam();
@@ -617,6 +639,7 @@ Symbol Symbol::withConcreteSubstitutions(
617639
case Kind::Protocol:
618640
case Kind::AssociatedType:
619641
case Kind::Shape:
642+
case Kind::PackElement:
620643
case Kind::Layout:
621644
break;
622645
}
@@ -734,6 +757,11 @@ void Symbol::dump(llvm::raw_ostream &out) const {
734757
out << "[shape]";
735758
return;
736759
}
760+
761+
case Kind::PackElement: {
762+
out << "[element]";
763+
return;
764+
}
737765
}
738766

739767
llvm_unreachable("Bad symbol kind");
@@ -756,6 +784,7 @@ void Symbol::Storage::Profile(llvm::FoldingSetNodeID &id) const {
756784
return;
757785

758786
case Symbol::Kind::GenericParam:
787+
case Symbol::Kind::PackElement:
759788
id.AddPointer(GenericParam);
760789
return;
761790

lib/AST/RequirementMachine/Symbol.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class Symbol final {
117117
/// generic parameter.
118118
Shape,
119119

120+
/// A pack element [element].(each T) where 'each T' is a type
121+
/// parameter pack.
122+
PackElement,
123+
120124
//////
121125
////// "Property-like" symbol kinds:
122126
//////
@@ -206,6 +210,8 @@ class Symbol final {
206210

207211
static Symbol forShape(RewriteContext &ctx);
208212

213+
static Symbol forPackElement(RewriteContext &Ctx);
214+
209215
static Symbol forLayout(LayoutConstraint layout,
210216
RewriteContext &ctx);
211217

lib/AST/RequirementMachine/Term.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,32 @@ static llvm::Optional<int> shortlexCompare(const Symbol *lhsBegin,
133133
const Symbol *rhsBegin,
134134
const Symbol *rhsEnd,
135135
RewriteContext &ctx) {
136-
// First, compare the number of name symbols.
136+
// First, compare the number of name and pack element symbols.
137137
unsigned lhsNameCount = 0;
138+
unsigned lhsPackElementCount = 0;
138139
for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
139140
if (iter->getKind() == Symbol::Kind::Name)
140141
++lhsNameCount;
142+
143+
if (iter->getKind() == Symbol::Kind::PackElement)
144+
++lhsPackElementCount;
141145
}
142146

143147
unsigned rhsNameCount = 0;
148+
unsigned rhsPackElementCount = 0;
144149
for (auto *iter = rhsBegin; iter != rhsEnd; ++iter) {
145150
if (iter->getKind() == Symbol::Kind::Name)
146151
++rhsNameCount;
152+
153+
if (iter->getKind() == Symbol::Kind::PackElement)
154+
++rhsPackElementCount;
147155
}
148156

157+
// A term with more pack element symbols orders after a term with
158+
// fewer pack element symbols.
159+
if (lhsPackElementCount != rhsPackElementCount)
160+
return lhsPackElementCount > rhsPackElementCount ? 1 : -1;
161+
149162
// A term with more name symbols orders after a term with fewer name symbols.
150163
if (lhsNameCount != rhsNameCount)
151164
return lhsNameCount > rhsNameCount ? 1 : -1;

0 commit comments

Comments
 (0)