Skip to content

Commit f0d59d5

Browse files
committed
RequirementMachine: Handle equivalence classes that both have a concrete type and protocol requirements
If a type parameter has a protocol conformance and a concrete type, we want to map associated types of the conformance to their concrete type witnesses. This is implemented as a post-processing pass in the completion procedure that runs after the equivalence class map has been built. If we have an equivalence class T => { conforms_to: [ P ], concrete: Foo }, then for each associated type A of P, we generate a new rule: T.[P:A].[concrete: Foo.A] => T.[P:A] (if Foo.A is concrete) T.[P:A] => T.(Foo.A) (if Foo.A is abstract) If this process introduced any new rules, we check for any new overlaps by re-running Knuth-Bendix completion; this may in turn introduce new concrete associated type overlaps, and so on. The overall completion procedure now alternates between Knuth-Bendix and rebuilding the equivalence class map; the rewrite system is complete when neither step is able to introduce any new rules.
1 parent 35daf17 commit f0d59d5

File tree

6 files changed

+234
-5
lines changed

6 files changed

+234
-5
lines changed

lib/AST/RequirementMachine/EquivalenceClassMap.cpp

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
#include "swift/AST/Decl.h"
5252
#include "swift/AST/LayoutConstraint.h"
53+
#include "swift/AST/Module.h"
54+
#include "swift/AST/ProtocolConformance.h"
5355
#include "swift/AST/TypeMatcher.h"
5456
#include "swift/AST/Types.h"
5557
#include "llvm/Support/raw_ostream.h"
@@ -511,6 +513,157 @@ void EquivalenceClassMap::addProperty(
511513
inducedRules, DebugConcreteUnification);
512514
}
513515

516+
void EquivalenceClassMap::concretizeNestedTypesFromConcreteParents(
517+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const {
518+
for (const auto &equivClass : Map) {
519+
if (equivClass->isConcreteType() &&
520+
!equivClass->getConformsTo().empty()) {
521+
if (DebugConcretizeNestedTypes) {
522+
llvm::dbgs() << "^ Concretizing nested types of ";
523+
equivClass->dump(llvm::dbgs());
524+
llvm::dbgs() << "\n";
525+
}
526+
527+
concretizeNestedTypesFromConcreteParent(
528+
equivClass->getKey(),
529+
equivClass->ConcreteType->getConcreteType(),
530+
equivClass->ConcreteType->getSubstitutions(),
531+
equivClass->getConformsTo(),
532+
inducedRules);
533+
}
534+
}
535+
}
536+
537+
/// If we have an equivalence class T => { conforms_to: [ P ], concrete: Foo },
538+
/// then for each associated type A of P, we generate a new rule:
539+
///
540+
/// T.[P:A].[concrete: Foo.A] => T.[P:A] (if Foo.A is concrete)
541+
/// T.[P:A] => T.(Foo.A) (if Foo.A is abstract)
542+
///
543+
void EquivalenceClassMap::concretizeNestedTypesFromConcreteParent(
544+
const MutableTerm &key,
545+
CanType concreteType, ArrayRef<Term> substitutions,
546+
ArrayRef<const ProtocolDecl *> conformsTo,
547+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const {
548+
for (auto *proto : conformsTo) {
549+
// FIXME: Either remove the ModuleDecl entirely from conformance lookup,
550+
// or pass the correct one down in here.
551+
auto *module = proto->getParentModule();
552+
553+
auto conformance = module->lookupConformance(concreteType,
554+
const_cast<ProtocolDecl *>(proto));
555+
if (conformance.isInvalid()) {
556+
// FIXME: Diagnose conflict
557+
if (DebugConcretizeNestedTypes) {
558+
llvm::dbgs() << "^^ " << concreteType << " does not conform to "
559+
<< proto->getName() << "\n";
560+
}
561+
562+
continue;
563+
}
564+
565+
// FIXME: Maybe this can happen if the concrete type is an
566+
// opaque result type?
567+
assert(!conformance.isAbstract());
568+
569+
auto assocTypes = Protos.getProtocolInfo(proto).AssociatedTypes;
570+
if (assocTypes.empty())
571+
continue;
572+
573+
auto *concrete = conformance.getConcrete();
574+
575+
// We might have duplicates in the list due to diamond inheritance.
576+
// FIXME: Filter those out further upstream?
577+
// FIXME: This should actually be outside of the loop over the conforming protos...
578+
llvm::SmallDenseSet<AssociatedTypeDecl *, 4> visited;
579+
for (auto *assocType : assocTypes) {
580+
if (!visited.insert(assocType).second)
581+
continue;
582+
583+
// Get the actual protocol in case we inherited this associated type.
584+
auto *actualProto = assocType->getProtocol();
585+
if (actualProto != proto)
586+
continue;
587+
588+
if (DebugConcretizeNestedTypes) {
589+
llvm::dbgs() << "^^ " << "Looking up type witness for "
590+
<< proto->getName() << ":" << assocType->getName()
591+
<< " on " << concreteType << "\n";
592+
}
593+
594+
auto typeWitness = concrete->getTypeWitness(assocType)
595+
->getCanonicalType();
596+
597+
if (DebugConcretizeNestedTypes) {
598+
llvm::dbgs() << "^^ " << "Type witness for " << assocType->getName()
599+
<< " of " << concreteType << " is " << typeWitness << "\n";
600+
}
601+
602+
auto nestedType = Atom::forAssociatedType(proto, assocType->getName(),
603+
Context);
604+
605+
MutableTerm subjectType = key;
606+
subjectType.add(nestedType);
607+
608+
MutableTerm constraintType;
609+
610+
if (concreteType == typeWitness) {
611+
if (DebugConcretizeNestedTypes) {
612+
llvm::dbgs() << "^^ Type witness is the same as the concrete type\n";
613+
}
614+
615+
// Add a rule T.[P:A] => T.
616+
constraintType = key;
617+
618+
} else if (typeWitness->isTypeParameter()) {
619+
// The type witness is a type parameter of the form τ_0_n.X.Y...Z,
620+
// where 'n' is an index into the substitution array.
621+
//
622+
// Collect zero or more member type names in reverse order.
623+
SmallVector<Atom, 3> atoms;
624+
while (auto memberType = dyn_cast<DependentMemberType>(typeWitness)) {
625+
atoms.push_back(Atom::forName(memberType->getName(), Context));
626+
typeWitness = memberType.getBase();
627+
}
628+
629+
// Get the substitution S corresponding to τ_0_n.
630+
unsigned index = getGenericParamIndex(typeWitness);
631+
constraintType = MutableTerm(substitutions[index]);
632+
633+
// Add the member type names.
634+
std::reverse(atoms.begin(), atoms.end());
635+
for (auto atom : atoms)
636+
constraintType.add(atom);
637+
638+
// Add a rule T => S.X.Y...Z.
639+
640+
} else {
641+
// The type witness is a concrete type.
642+
constraintType = subjectType;
643+
644+
// FIXME: Handle dependent member types here
645+
SmallVector<Term, 3> result;
646+
auto typeWitnessSchema =
647+
remapConcreteSubstitutionSchema(typeWitness, substitutions,
648+
Context.getASTContext(),
649+
result);
650+
constraintType.add(
651+
Atom::forConcreteType(
652+
typeWitnessSchema, result, Context));
653+
654+
// Add a rule T.[P:A].[concrete: Foo.A] => T.[P:A].
655+
656+
}
657+
658+
inducedRules.emplace_back(subjectType, constraintType);
659+
if (DebugConcretizeNestedTypes) {
660+
llvm::dbgs() << "^^ Induced rule " << constraintType
661+
<< " => " << subjectType << "\n";
662+
}
663+
}
664+
}
665+
}
666+
514667
void EquivalenceClassMap::dump(llvm::raw_ostream &out) const {
515668
out << "Equivalence class map: {\n";
516669
for (const auto &equivClass : Map) {
@@ -581,6 +734,10 @@ RewriteSystem::buildEquivalenceClassMap(EquivalenceClassMap &map,
581734
map.addProperty(pair.first, pair.second, inducedRules);
582735
}
583736

737+
// We also need to merge concrete type rules with conformance rules, by
738+
// concretizing the associated type witnesses of the concrete type.
739+
map.concretizeNestedTypesFromConcreteParents(inducedRules);
740+
584741
// Some of the induced rules might be trivial; only count the induced rules
585742
// where the left hand side is not already equivalent to the right hand side.
586743
unsigned addedNewRules = 0;

lib/AST/RequirementMachine/EquivalenceClassMap.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ class EquivalenceClassMap {
108108
RewriteContext &Context;
109109
std::vector<std::unique_ptr<EquivalenceClass>> Map;
110110
const ProtocolGraph &Protos;
111-
bool DebugConcreteUnification = false;
111+
unsigned DebugConcreteUnification : 1;
112+
unsigned DebugConcretizeNestedTypes : 1;
112113

113114
EquivalenceClass *getEquivalenceClassIfPresent(const MutableTerm &key) const;
114115
EquivalenceClass *getOrCreateEquivalenceClass(const MutableTerm &key);
@@ -121,14 +122,28 @@ class EquivalenceClassMap {
121122
public:
122123
explicit EquivalenceClassMap(RewriteContext &ctx,
123124
const ProtocolGraph &protos)
124-
: Context(ctx), Protos(protos) {}
125+
: Context(ctx), Protos(protos) {
126+
DebugConcreteUnification = false;
127+
DebugConcretizeNestedTypes = false;
128+
}
125129

126130
EquivalenceClass *lookUpEquivalenceClass(const MutableTerm &key) const;
127131

132+
void dump(llvm::raw_ostream &out) const;
133+
128134
void clear();
129135
void addProperty(const MutableTerm &key, Atom property,
130136
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules);
131-
void dump(llvm::raw_ostream &out) const;
137+
void concretizeNestedTypesFromConcreteParents(
138+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const;
139+
140+
private:
141+
void concretizeNestedTypesFromConcreteParent(
142+
const MutableTerm &key,
143+
CanType concreteType,
144+
ArrayRef<Term> substitutions,
145+
ArrayRef<const ProtocolDecl *> conformsTo,
146+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const;
132147
};
133148

134149
} // end namespace rewriting

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ Identifier Atom::getName() const {
133133
return Ptr->Name;
134134
}
135135

136-
/// Get the single protocol declaration associate with a protocol atom.
136+
/// Get the single protocol declaration associated with a protocol atom.
137137
const ProtocolDecl *Atom::getProtocol() const {
138138
assert(getKind() == Kind::Protocol);
139139
return Ptr->Proto;

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ void RewriteSystem::processMergedAssociatedTypes() {
334334
// merged type [P1&P2:T] must conform to Q as well. Add a new rule
335335
// of the form:
336336
//
337-
// [P1&P2].[Q] => [P1&P2]
337+
// [P1&P2:T].[Q] => [P1&P2:T]
338338
//
339339
MutableTerm newLHS;
340340
newLHS.add(mergedAtom);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %target-typecheck-verify-swift -enable-requirement-machine -debug-requirement-machine 2>&1 | %FileCheck %s
2+
3+
protocol P1 {
4+
associatedtype T : P1
5+
}
6+
7+
protocol P2 {
8+
associatedtype T where T == Int
9+
}
10+
11+
extension Int : P1 {
12+
public typealias T = Int
13+
}
14+
15+
struct G<T : P1 & P2> {}
16+
17+
// Since G.T.T == G.T.T.T == G.T.T.T.T = ... = Int, we tie off the
18+
// recursion by introducing a same-type requirement G.T.T => G.T.
19+
20+
// CHECK-LABEL: Adding generic signature <τ_0_0 where τ_0_0 : P1, τ_0_0 : P2> {
21+
// CHECK-LABEL: Rewrite system: {
22+
// CHECK: - τ_0_0.[P1&P2:T].[concrete: Int] => τ_0_0.[P1&P2:T]
23+
// CHECK: - [P1&P2:T].T => [P1&P2:T].[P1:T]
24+
// CHECK: - τ_0_0.[P1&P2:T].[P1:T] => τ_0_0.[P1&P2:T]
25+
// CHECK: }
26+
// CHECK: }
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %target-typecheck-verify-swift -enable-requirement-machine -debug-requirement-machine 2>&1 | %FileCheck %s
2+
3+
protocol P1 {
4+
associatedtype T : P1
5+
}
6+
7+
protocol P2 {
8+
associatedtype T where T == X<U>
9+
associatedtype U
10+
}
11+
12+
extension Int : P1 {
13+
public typealias T = Int
14+
}
15+
16+
struct X<A> : P1 {
17+
typealias T = X<A>
18+
}
19+
20+
struct G<T : P1 & P2> {}
21+
22+
// Since G.T.T == G.T.T.T == G.T.T.T.T = ... = X<T.U>, we tie off the
23+
// recursion by introducing a same-type requirement G.T.T => G.T.
24+
25+
// CHECK-LABEL: Adding generic signature <τ_0_0 where τ_0_0 : P1, τ_0_0 : P2> {
26+
// CHECK-LABEL: Rewrite system: {
27+
// CHECK: - τ_0_0.[P1&P2:T].[concrete: X<τ_0_0> with <τ_0_0.[P2:U]>] => τ_0_0.[P1&P2:T]
28+
// CHECK: - [P1&P2:T].T => [P1&P2:T].[P1:T]
29+
// CHECK: - τ_0_0.[P1&P2:T].[P1:T] => τ_0_0.[P1&P2:T]
30+
// CHECK: }
31+
// CHECK: }

0 commit comments

Comments
 (0)