Skip to content

Commit 4a49419

Browse files
committed
RequirementMachine: Same-type requirements imply same-shape requirements
We want `T.A == U.B` to imply `shape(T) == shape(U)` if T (and thus U) is a parameter pack. To do this, we introduce some new rewrite rules: 1) For each associated type symbol `[P:A]`, a rule `([P:A].[shape] => [P:A])`. 2) For each non-pack generic parameter `τ_d_i`, a rule `τ_d_i.[shape] => [shape]`. Now consider a rewrite rule `(τ_d_i.[P:A] => τ_D_I.[Q:B])`. The left-hand side overlaps with the rule `([P:A].[shape] => [shape])` on the term `τ_d_i.[P:A].[shape]`. Resolving the overlap gives us a new rule t_d_i.[shape] => T_D_I.[shape] If T is a term corresponding to some type parameter, we say that `T.[shape]` is a shape term. If `T'.[shape]` is a reduced term, we say that T' is the reduced shape of T. Recall that shape requirements are represented as rules of the form: τ_d_i.[shape] => τ_D_I.[shape] Now, the rules of the first kind reduce our shape term `T.[shape]` to `τ_d_i.[shape]`, where `τ_d_i` is the root generic parameter of T. If `τ_d_i` is not a pack, a rule of the second kind reduces it to `[shape]`, so the reduced shape of a non-pack parameter T is the empty term. Otherwise, if `τ_d_i` is a pack, `τ_d_i.[shape]` might reduce to `τ_D_I.[shape]` via a shape requirement. In this case, `τ_D_I` is the reduced shape of T. Fixes rdar://problem/101813873.
1 parent 89ad597 commit 4a49419

File tree

6 files changed

+128
-25
lines changed

6 files changed

+128
-25
lines changed

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ RequirementMachine::initWithProtocolSignatureRequirements(
307307
///
308308
/// Returns failure if completion fails within the configured number of steps.
309309
std::pair<CompletionResult, unsigned>
310-
RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
310+
RequirementMachine::initWithGenericSignature(GenericSignature sig) {
311311
Sig = sig;
312312
Params.append(sig.getGenericParams().begin(),
313313
sig.getGenericParams().end());
@@ -323,7 +323,8 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
323323
// Collect the top-level requirements, and all transitively-referenced
324324
// protocol requirement signatures.
325325
RuleBuilder builder(Context, System.getReferencedProtocols());
326-
builder.initWithGenericSignatureRequirements(sig.getRequirements());
326+
builder.initWithGenericSignature(sig.getGenericParams(),
327+
sig.getRequirements());
327328

328329
// Add the initial set of rewrite rules to the rewrite system.
329330
System.initialize(/*recordLoops=*/false,
@@ -425,7 +426,7 @@ RequirementMachine::initWithWrittenRequirements(
425426
// Collect the top-level requirements, and all transitively-referenced
426427
// protocol requirement signatures.
427428
RuleBuilder builder(Context, System.getReferencedProtocols());
428-
builder.initWithWrittenRequirements(requirements);
429+
builder.initWithWrittenRequirements(genericParams, requirements);
429430

430431
// Add the initial set of rewrite rules to the rewrite system.
431432
System.initialize(/*recordLoops=*/true,

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class RequirementMachine final {
5454
friend class swift::AbstractGenericSignatureRequest;
5555
friend class swift::InferredGenericSignatureRequest;
5656

57-
CanGenericSignature Sig;
57+
GenericSignature Sig;
5858
SmallVector<GenericTypeParamType *, 2> Params;
5959

6060
RewriteContext &Context;
@@ -95,7 +95,7 @@ class RequirementMachine final {
9595
ArrayRef<const ProtocolDecl *> protos);
9696

9797
std::pair<CompletionResult, unsigned>
98-
initWithGenericSignature(CanGenericSignature sig);
98+
initWithGenericSignature(GenericSignature sig);
9999

100100
std::pair<CompletionResult, unsigned>
101101
initWithProtocolWrittenRequirements(

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,16 @@ void RewriteSystem::recordRewriteLoop(MutableTerm basepoint,
504504
return;
505505

506506
// Ignore the rewrite loop if it is not part of our minimization domain.
507-
if (!isInMinimizationDomain(basepoint.getRootProtocol()))
507+
//
508+
// Completion might record a rewrite loop where the basepoint is just
509+
// the term [shape]. In this case though, we know it's in our domain,
510+
// since completion only checks local rules for overlap. Other callers
511+
// of recordRewriteLoop() always pass in a valid basepoint, so we
512+
// check.
513+
if (basepoint[0].getKind() != Symbol::Kind::Shape &&
514+
!isInMinimizationDomain(basepoint.getRootProtocol())) {
508515
return;
516+
}
509517

510518
Loops.push_back(loop);
511519
}
@@ -555,11 +563,6 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
555563
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
556564
}
557565

558-
// A shape symbol must follow a generic param symbol
559-
if (symbol.getKind() == Symbol::Kind::Shape) {
560-
ASSERT_RULE(index > 0 && lhs[index - 1].getKind() == Symbol::Kind::GenericParam);
561-
}
562-
563566
if (!rule.isLHSSimplified() &&
564567
index != lhs.size() - 1) {
565568
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteConformance);
@@ -602,15 +605,10 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
602605
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Superclass);
603606
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteType);
604607

605-
if (index != lhs.size() - 1) {
608+
if (index != rhs.size() - 1) {
606609
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
607610
}
608611

609-
// A shape symbol must follow a generic param symbol
610-
if (symbol.getKind() == Symbol::Kind::Shape) {
611-
ASSERT_RULE(index > 0 && rhs[index - 1].getKind() == Symbol::Kind::GenericParam);
612-
}
613-
614612
// Completion can introduce a rule of the form
615613
//
616614
// (T.[P] => T.[concrete: C : P])
@@ -635,10 +633,15 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
635633
}
636634
}
637635

638-
auto lhsDomain = lhs.getRootProtocol();
639-
auto rhsDomain = rhs.getRootProtocol();
640-
641-
ASSERT_RULE(lhsDomain == rhsDomain);
636+
if (rhs.size() == 1 && rhs[0].getKind() == Symbol::Kind::Shape) {
637+
// We can have a rule like T.[shape] => [shape].
638+
ASSERT_RULE(lhs.back().getKind() == Symbol::Kind::Shape);
639+
} else {
640+
// Otherwise, LHS and RHS must have the same domain.
641+
auto lhsDomain = lhs.getRootProtocol();
642+
auto rhsDomain = rhs.getRootProtocol();
643+
ASSERT_RULE(lhsDomain == rhsDomain);
644+
}
642645
}
643646

644647
#undef ASSERT_RULE

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ using namespace rewriting;
3434

3535
/// For building a rewrite system for a generic signature from canonical
3636
/// requirements.
37-
void RuleBuilder::initWithGenericSignatureRequirements(
37+
void RuleBuilder::initWithGenericSignature(
38+
ArrayRef<GenericTypeParamType *> genericParams,
3839
ArrayRef<Requirement> requirements) {
3940
assert(!Initialized);
4041
Initialized = 1;
@@ -47,6 +48,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
4748
}
4849

4950
collectRulesFromReferencedProtocols();
51+
collectPackShapeRules(genericParams);
5052

5153
// Add rewrite rules for all top-level requirements.
5254
for (const auto &req : requirements)
@@ -56,6 +58,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
5658
/// For building a rewrite system for a generic signature from user-written
5759
/// requirements.
5860
void RuleBuilder::initWithWrittenRequirements(
61+
ArrayRef<GenericTypeParamType *> genericParams,
5962
ArrayRef<StructuralRequirement> requirements) {
6063
assert(!Initialized);
6164
Initialized = 1;
@@ -68,6 +71,7 @@ void RuleBuilder::initWithWrittenRequirements(
6871
}
6972

7073
collectRulesFromReferencedProtocols();
74+
collectPackShapeRules(genericParams);
7175

7276
// Add rewrite rules for all top-level requirements.
7377
for (const auto &req : requirements)
@@ -488,3 +492,77 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
488492
localRules.end());
489493
}
490494
}
495+
496+
void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams) {
497+
if (Dump) {
498+
llvm::dbgs() << "adding shape rules\n";
499+
}
500+
501+
if (!llvm::any_of(genericParams,
502+
[](GenericTypeParamType *t) {
503+
return t->isParameterPack();
504+
})) {
505+
return;
506+
}
507+
508+
// Each non-pack generic parameter is part of the "scalar shape class", represented
509+
// by the empty term.
510+
for (auto *genericParam : genericParams) {
511+
if (genericParam->isParameterPack())
512+
continue;
513+
514+
// Add the rule (τ_d_i.[shape] => [shape]).
515+
MutableTerm lhs;
516+
lhs.add(Symbol::forGenericParam(
517+
cast<GenericTypeParamType>(genericParam->getCanonicalType()), Context));
518+
lhs.add(Symbol::forShape(Context));
519+
520+
MutableTerm rhs;
521+
rhs.add(Symbol::forShape(Context));
522+
523+
PermanentRules.emplace_back(lhs, rhs);
524+
}
525+
526+
// A member type T.[P:A] is part of the same shape class as its base type T.
527+
llvm::DenseSet<Symbol> visited;
528+
529+
auto addMemberShapeRule = [&](const ProtocolDecl *proto, AssociatedTypeDecl *assocType) {
530+
auto symbol = Symbol::forAssociatedType(proto, assocType->getName(), Context);
531+
if (!visited.insert(symbol).second)
532+
return;
533+
534+
// Add the rule ([P:A].[shape] => [shape]).
535+
MutableTerm lhs;
536+
lhs.add(symbol);
537+
lhs.add(Symbol::forShape(Context));
538+
539+
MutableTerm rhs;
540+
rhs.add(Symbol::forShape(Context));
541+
542+
// Consider it an imported rule, since it is not part of our minimization
543+
// domain. It would be more logical if we added these in the protocol component
544+
// machine for this protocol, but instead we add them in the "leaf" generic
545+
// signature machine. This avoids polluting machines that do not involve
546+
// parameter packs with these extra rules, which would otherwise just slow
547+
// things down.
548+
Rule rule(Term::get(lhs, Context), Term::get(rhs, Context));
549+
rule.markPermanent();
550+
ImportedRules.push_back(rule);
551+
};
552+
553+
for (auto *proto : ProtocolsToImport) {
554+
if (Dump) {
555+
llvm::dbgs() << "adding member shape rules for protocol " << proto->getName() << "\n";
556+
}
557+
558+
for (auto *assocType : proto->getAssociatedTypeMembers()) {
559+
addMemberShapeRule(proto, assocType);
560+
}
561+
562+
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
563+
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
564+
addMemberShapeRule(proto, assocType);
565+
}
566+
}
567+
}
568+
}

lib/AST/RequirementMachine/RuleBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,10 @@ struct RuleBuilder {
9595
Initialized = 0;
9696
}
9797

98-
void initWithGenericSignatureRequirements(ArrayRef<Requirement> requirements);
99-
void initWithWrittenRequirements(ArrayRef<StructuralRequirement> requirements);
98+
void initWithGenericSignature(ArrayRef<GenericTypeParamType *> genericParams,
99+
ArrayRef<Requirement> requirements);
100+
void initWithWrittenRequirements(ArrayRef<GenericTypeParamType *> genericParams,
101+
ArrayRef<StructuralRequirement> requirements);
100102
void initWithProtocolSignatureRequirements(ArrayRef<const ProtocolDecl *> proto);
101103
void initWithProtocolWrittenRequirements(
102104
ArrayRef<const ProtocolDecl *> component,
@@ -106,6 +108,7 @@ struct RuleBuilder {
106108
ArrayRef<Term> substitutions);
107109
void addReferencedProtocol(const ProtocolDecl *proto);
108110
void collectRulesFromReferencedProtocols();
111+
void collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams);
109112

110113
private:
111114
void addPermanentProtocolRules(const ProtocolDecl *proto);

test/Generics/pack-shape-requirements.swift

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -disable-availability-checking 2>&1 | %FileCheck %s
22

33
protocol P {
4-
associatedtype A
4+
associatedtype A: P
55
}
66

77
// CHECK-LABEL: inferSameShape(ts:us:)
@@ -65,3 +65,21 @@ struct Ts<each T> {
6565
func expandedParameters<each T, each Result>(_ t: repeat each T, transform: repeat (each T) -> each Result) -> (repeat each Result) {
6666
fatalError()
6767
}
68+
69+
70+
//////
71+
///
72+
/// Same-type requirements should imply same-shape requirements.
73+
///
74+
//////
75+
76+
// CHECK-LABEL: sameType1
77+
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : P, repeat each U : P, repeat (each T).[P]A == (each U).[P]A>
78+
func sameType1<each T, each U>(_: repeat (each T, each U)) where repeat each T: P, repeat each U: P, repeat each T.A == each U.A {}
79+
80+
// Make sure inherited associated types are handled
81+
protocol Q: P where A: Q {}
82+
83+
// CHECK-LABEL: sameType2
84+
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : Q, repeat each U : Q, repeat (each T).[P]A.[P]A == (each U).[P]A.[P]A>
85+
func sameType2<each T, each U>(_: repeat (each T, each U)) where repeat each T: Q, repeat each U: Q, repeat each T.A.A == each U.A.A {}

0 commit comments

Comments
 (0)