Skip to content

Commit 22e24d3

Browse files
authored
Merge pull request #38284 from slavapestov/requirement-machine-canonical-types
RequirementMachine: Implement GenericSignature::getCanonicalTypeInContext()
2 parents 8530969 + 9a0c87b commit 22e24d3

12 files changed

+883
-60
lines changed

include/swift/AST/RequirementMachine.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace swift {
2424
class ASTContext;
2525
class AssociatedTypeDecl;
2626
class CanType;
27+
class GenericTypeParamType;
2728
class LayoutConstraint;
2829
class ProtocolDecl;
2930
class Requirement;
@@ -49,7 +50,7 @@ class RequirementMachine final {
4950
void addGenericSignature(CanGenericSignature sig);
5051

5152
bool isComplete() const;
52-
void computeCompletion(CanGenericSignature sig);
53+
void computeCompletion();
5354

5455
public:
5556
~RequirementMachine();
@@ -61,6 +62,8 @@ class RequirementMachine final {
6162
GenericSignature::RequiredProtocols getRequiredProtocols(Type depType) const;
6263
bool isConcreteType(Type depType) const;
6364
bool areSameTypeParameterInContext(Type depType1, Type depType2) const;
65+
Type getCanonicalTypeInContext(Type type,
66+
TypeArrayView<GenericTypeParamType> genericParams) const;
6467

6568
void dump(llvm::raw_ostream &out) const;
6669
};

lib/AST/GenericSignature.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,8 +743,41 @@ CanType GenericSignatureImpl::getCanonicalTypeInContext(Type type) const {
743743
if (!type->hasTypeParameter())
744744
return CanType(type);
745745

746-
auto &builder = *getGenericSignatureBuilder();
747-
return builder.getCanonicalTypeInContext(type, { })->getCanonicalType();
746+
auto computeViaGSB = [&]() {
747+
auto &builder = *getGenericSignatureBuilder();
748+
return builder.getCanonicalTypeInContext(type, { })->getCanonicalType();
749+
};
750+
751+
auto computeViaRQM = [&]() {
752+
auto *machine = getRequirementMachine();
753+
return machine->getCanonicalTypeInContext(type, { })->getCanonicalType();
754+
};
755+
756+
auto &ctx = getASTContext();
757+
if (ctx.LangOpts.EnableRequirementMachine) {
758+
auto rqmResult = computeViaRQM();
759+
760+
#ifndef NDEBUG
761+
auto gsbResult = computeViaGSB();
762+
763+
if (gsbResult != rqmResult) {
764+
llvm::errs() << "RequirementMachine::getCanonicalTypeInContext() is broken\n";
765+
llvm::errs() << "Generic signature: " << GenericSignature(this) << "\n";
766+
llvm::errs() << "Dependent type: "; type.dump(llvm::errs());
767+
llvm::errs() << "GenericSignatureBuilder says: " << gsbResult << "\n";
768+
gsbResult.dump(llvm::errs());
769+
llvm::errs() << "RequirementMachine says: " << rqmResult << "\n";
770+
rqmResult.dump(llvm::errs());
771+
llvm::errs() << "\n";
772+
getRequirementMachine()->dump(llvm::errs());
773+
abort();
774+
}
775+
#endif
776+
777+
return rqmResult;
778+
} else {
779+
return computeViaGSB();
780+
}
748781
}
749782

750783
ArrayRef<CanTypeWrapper<GenericTypeParamType>>

lib/AST/RequirementMachine/EquivalenceClassMap.cpp

Lines changed: 218 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
#include "swift/AST/Decl.h"
5252
#include "swift/AST/LayoutConstraint.h"
53+
#include "swift/AST/Module.h"
54+
#include "swift/AST/ProtocolConformance.h"
5355
#include "swift/AST/TypeMatcher.h"
5456
#include "swift/AST/Types.h"
5557
#include "llvm/Support/raw_ostream.h"
@@ -102,6 +104,45 @@ static unsigned getGenericParamIndex(Type type) {
102104
return paramTy->getIndex();
103105
}
104106

107+
/// Reverses the transformation performed by
108+
/// RewriteSystemBuilder::getConcreteSubstitutionSchema().
109+
static Type getTypeFromSubstitutionSchema(Type schema,
110+
ArrayRef<Term> substitutions,
111+
TypeArrayView<GenericTypeParamType> genericParams,
112+
const ProtocolGraph &protos,
113+
RewriteContext &ctx) {
114+
assert(!schema->isTypeParameter() && "Must have a concrete type here");
115+
116+
if (!schema->hasTypeParameter())
117+
return schema;
118+
119+
return schema.transformRec([&](Type t) -> Optional<Type> {
120+
if (t->is<GenericTypeParamType>()) {
121+
auto index = getGenericParamIndex(t);
122+
123+
return ctx.getTypeForTerm(substitutions[index],
124+
genericParams, protos);
125+
}
126+
127+
assert(!t->isTypeParameter());
128+
return None;
129+
});
130+
}
131+
132+
/// Get the concrete type of this equivalence class.
133+
///
134+
/// Asserts if this equivalence class is not concrete.
135+
Type EquivalenceClass::getConcreteType(
136+
TypeArrayView<GenericTypeParamType> genericParams,
137+
const ProtocolGraph &protos,
138+
RewriteContext &ctx) const {
139+
return getTypeFromSubstitutionSchema(ConcreteType->getConcreteType(),
140+
ConcreteType->getSubstitutions(),
141+
genericParams,
142+
protos,
143+
ctx);
144+
}
145+
105146
/// Given a concrete type that is a structural sub-component of a concrete
106147
/// type produced by RewriteSystemBuilder::getConcreteSubstitutionSchema(),
107148
/// collect the subset of referenced substitutions and renumber the generic
@@ -472,6 +513,157 @@ void EquivalenceClassMap::addProperty(
472513
inducedRules, DebugConcreteUnification);
473514
}
474515

516+
void EquivalenceClassMap::concretizeNestedTypesFromConcreteParents(
517+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const {
518+
for (const auto &equivClass : Map) {
519+
if (equivClass->isConcreteType() &&
520+
!equivClass->getConformsTo().empty()) {
521+
if (DebugConcretizeNestedTypes) {
522+
llvm::dbgs() << "^ Concretizing nested types of ";
523+
equivClass->dump(llvm::dbgs());
524+
llvm::dbgs() << "\n";
525+
}
526+
527+
concretizeNestedTypesFromConcreteParent(
528+
equivClass->getKey(),
529+
equivClass->ConcreteType->getConcreteType(),
530+
equivClass->ConcreteType->getSubstitutions(),
531+
equivClass->getConformsTo(),
532+
inducedRules);
533+
}
534+
}
535+
}
536+
537+
/// If we have an equivalence class T => { conforms_to: [ P ], concrete: Foo },
538+
/// then for each associated type A of P, we generate a new rule:
539+
///
540+
/// T.[P:A].[concrete: Foo.A] => T.[P:A] (if Foo.A is concrete)
541+
/// T.[P:A] => T.(Foo.A) (if Foo.A is abstract)
542+
///
543+
void EquivalenceClassMap::concretizeNestedTypesFromConcreteParent(
544+
const MutableTerm &key,
545+
CanType concreteType, ArrayRef<Term> substitutions,
546+
ArrayRef<const ProtocolDecl *> conformsTo,
547+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const {
548+
for (auto *proto : conformsTo) {
549+
// FIXME: Either remove the ModuleDecl entirely from conformance lookup,
550+
// or pass the correct one down in here.
551+
auto *module = proto->getParentModule();
552+
553+
auto conformance = module->lookupConformance(concreteType,
554+
const_cast<ProtocolDecl *>(proto));
555+
if (conformance.isInvalid()) {
556+
// FIXME: Diagnose conflict
557+
if (DebugConcretizeNestedTypes) {
558+
llvm::dbgs() << "^^ " << concreteType << " does not conform to "
559+
<< proto->getName() << "\n";
560+
}
561+
562+
continue;
563+
}
564+
565+
// FIXME: Maybe this can happen if the concrete type is an
566+
// opaque result type?
567+
assert(!conformance.isAbstract());
568+
569+
auto assocTypes = Protos.getProtocolInfo(proto).AssociatedTypes;
570+
if (assocTypes.empty())
571+
continue;
572+
573+
auto *concrete = conformance.getConcrete();
574+
575+
// We might have duplicates in the list due to diamond inheritance.
576+
// FIXME: Filter those out further upstream?
577+
// FIXME: This should actually be outside of the loop over the conforming protos...
578+
llvm::SmallDenseSet<AssociatedTypeDecl *, 4> visited;
579+
for (auto *assocType : assocTypes) {
580+
if (!visited.insert(assocType).second)
581+
continue;
582+
583+
// Get the actual protocol in case we inherited this associated type.
584+
auto *actualProto = assocType->getProtocol();
585+
if (actualProto != proto)
586+
continue;
587+
588+
if (DebugConcretizeNestedTypes) {
589+
llvm::dbgs() << "^^ " << "Looking up type witness for "
590+
<< proto->getName() << ":" << assocType->getName()
591+
<< " on " << concreteType << "\n";
592+
}
593+
594+
auto typeWitness = concrete->getTypeWitness(assocType)
595+
->getCanonicalType();
596+
597+
if (DebugConcretizeNestedTypes) {
598+
llvm::dbgs() << "^^ " << "Type witness for " << assocType->getName()
599+
<< " of " << concreteType << " is " << typeWitness << "\n";
600+
}
601+
602+
auto nestedType = Atom::forAssociatedType(proto, assocType->getName(),
603+
Context);
604+
605+
MutableTerm subjectType = key;
606+
subjectType.add(nestedType);
607+
608+
MutableTerm constraintType;
609+
610+
if (concreteType == typeWitness) {
611+
if (DebugConcretizeNestedTypes) {
612+
llvm::dbgs() << "^^ Type witness is the same as the concrete type\n";
613+
}
614+
615+
// Add a rule T.[P:A] => T.
616+
constraintType = key;
617+
618+
} else if (typeWitness->isTypeParameter()) {
619+
// The type witness is a type parameter of the form τ_0_n.X.Y...Z,
620+
// where 'n' is an index into the substitution array.
621+
//
622+
// Collect zero or more member type names in reverse order.
623+
SmallVector<Atom, 3> atoms;
624+
while (auto memberType = dyn_cast<DependentMemberType>(typeWitness)) {
625+
atoms.push_back(Atom::forName(memberType->getName(), Context));
626+
typeWitness = memberType.getBase();
627+
}
628+
629+
// Get the substitution S corresponding to τ_0_n.
630+
unsigned index = getGenericParamIndex(typeWitness);
631+
constraintType = MutableTerm(substitutions[index]);
632+
633+
// Add the member type names.
634+
std::reverse(atoms.begin(), atoms.end());
635+
for (auto atom : atoms)
636+
constraintType.add(atom);
637+
638+
// Add a rule T => S.X.Y...Z.
639+
640+
} else {
641+
// The type witness is a concrete type.
642+
constraintType = subjectType;
643+
644+
// FIXME: Handle dependent member types here
645+
SmallVector<Term, 3> result;
646+
auto typeWitnessSchema =
647+
remapConcreteSubstitutionSchema(typeWitness, substitutions,
648+
Context.getASTContext(),
649+
result);
650+
constraintType.add(
651+
Atom::forConcreteType(
652+
typeWitnessSchema, result, Context));
653+
654+
// Add a rule T.[P:A].[concrete: Foo.A] => T.[P:A].
655+
656+
}
657+
658+
inducedRules.emplace_back(subjectType, constraintType);
659+
if (DebugConcretizeNestedTypes) {
660+
llvm::dbgs() << "^^ Induced rule " << constraintType
661+
<< " => " << subjectType << "\n";
662+
}
663+
}
664+
}
665+
}
666+
475667
void EquivalenceClassMap::dump(llvm::raw_ostream &out) const {
476668
out << "Equivalence class map: {\n";
477669
for (const auto &equivClass : Map) {
@@ -485,10 +677,19 @@ void EquivalenceClassMap::dump(llvm::raw_ostream &out) const {
485677
/// Build the equivalence class map from all rules of the form T.[p] => T, where
486678
/// [p] is a property atom.
487679
///
488-
/// Returns true if concrete term unification performed while building the map
489-
/// introduced new rewrite rules; in this case, the completion procedure must
490-
/// run again.
491-
bool RewriteSystem::buildEquivalenceClassMap(EquivalenceClassMap &map) {
680+
/// Returns a pair consisting of a status and number of iterations executed.
681+
///
682+
/// The status is CompletionResult::MaxIterations if we exceed \p maxIterations
683+
/// iterations.
684+
///
685+
/// The status is CompletionResult::MaxDepth if we produce a rewrite rule whose
686+
/// left hand side has a length exceeding \p maxDepth.
687+
///
688+
/// Otherwise, the status is CompletionResult::Success.
689+
std::pair<RewriteSystem::CompletionResult, unsigned>
690+
RewriteSystem::buildEquivalenceClassMap(EquivalenceClassMap &map,
691+
unsigned maxIterations,
692+
unsigned maxDepth) {
492693
map.clear();
493694

494695
std::vector<std::pair<MutableTerm, Atom>> properties;
@@ -533,18 +734,25 @@ bool RewriteSystem::buildEquivalenceClassMap(EquivalenceClassMap &map) {
533734
map.addProperty(pair.first, pair.second, inducedRules);
534735
}
535736

737+
// We also need to merge concrete type rules with conformance rules, by
738+
// concretizing the associated type witnesses of the concrete type.
739+
map.concretizeNestedTypesFromConcreteParents(inducedRules);
740+
536741
// Some of the induced rules might be trivial; only count the induced rules
537742
// where the left hand side is not already equivalent to the right hand side.
538743
unsigned addedNewRules = 0;
539744
for (auto pair : inducedRules) {
540-
if (addRule(pair.first, pair.second))
745+
if (addRule(pair.first, pair.second)) {
541746
++addedNewRules;
542-
}
543747

544-
if (auto *stats = Context.getASTContext().Stats) {
545-
stats->getFrontendCounters()
546-
.NumRequirementMachineUnifiedConcreteTerms += addedNewRules;
748+
const auto &newRule = Rules.back();
749+
if (newRule.getLHS().size() > maxDepth)
750+
return std::make_pair(CompletionResult::MaxDepth, addedNewRules);
751+
}
547752
}
548753

549-
return addedNewRules > 0;
754+
if (Rules.size() > maxIterations)
755+
return std::make_pair(CompletionResult::MaxIterations, addedNewRules);
756+
757+
return std::make_pair(CompletionResult::Success, addedNewRules);
550758
}

lib/AST/RequirementMachine/EquivalenceClassMap.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ class EquivalenceClass {
8686
return ConcreteType.hasValue();
8787
}
8888

89+
Type getConcreteType(
90+
TypeArrayView<GenericTypeParamType> genericParams,
91+
const ProtocolGraph &protos,
92+
RewriteContext &ctx) const;
93+
8994
LayoutConstraint getLayoutConstraint() const {
9095
return Layout;
9196
}
@@ -103,7 +108,8 @@ class EquivalenceClassMap {
103108
RewriteContext &Context;
104109
std::vector<std::unique_ptr<EquivalenceClass>> Map;
105110
const ProtocolGraph &Protos;
106-
bool DebugConcreteUnification = false;
111+
unsigned DebugConcreteUnification : 1;
112+
unsigned DebugConcretizeNestedTypes : 1;
107113

108114
EquivalenceClass *getEquivalenceClassIfPresent(const MutableTerm &key) const;
109115
EquivalenceClass *getOrCreateEquivalenceClass(const MutableTerm &key);
@@ -116,14 +122,28 @@ class EquivalenceClassMap {
116122
public:
117123
explicit EquivalenceClassMap(RewriteContext &ctx,
118124
const ProtocolGraph &protos)
119-
: Context(ctx), Protos(protos) {}
125+
: Context(ctx), Protos(protos) {
126+
DebugConcreteUnification = false;
127+
DebugConcretizeNestedTypes = false;
128+
}
120129

121130
EquivalenceClass *lookUpEquivalenceClass(const MutableTerm &key) const;
122131

132+
void dump(llvm::raw_ostream &out) const;
133+
123134
void clear();
124135
void addProperty(const MutableTerm &key, Atom property,
125136
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules);
126-
void dump(llvm::raw_ostream &out) const;
137+
void concretizeNestedTypesFromConcreteParents(
138+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const;
139+
140+
private:
141+
void concretizeNestedTypesFromConcreteParent(
142+
const MutableTerm &key,
143+
CanType concreteType,
144+
ArrayRef<Term> substitutions,
145+
ArrayRef<const ProtocolDecl *> conformsTo,
146+
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const;
127147
};
128148

129149
} // end namespace rewriting

0 commit comments

Comments
 (0)