Skip to content

Commit 438c067

Browse files
authored
Merge pull request #67064 from slavapestov/rqm-same-type-same-shape
RequirementMachine: Same-type requirements imply same-shape requirements
2 parents 07f9384 + e01822c commit 438c067

File tree

10 files changed

+163
-42
lines changed

10 files changed

+163
-42
lines changed

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -721,20 +721,27 @@ MutableTerm
721721
RequirementMachine::getReducedShapeTerm(Type type) const {
722722
assert(type->isParameterPack());
723723

724-
auto rootType = type->getRootGenericParam();
725-
auto term = Context.getMutableTermForType(rootType->getCanonicalType(),
724+
auto term = Context.getMutableTermForType(type->getCanonicalType(),
726725
/*proto=*/nullptr);
727726

728-
// Append the 'shape' symbol to the term.
727+
// From a type term T, form the shape term `T.[shape]`.
729728
term.add(Symbol::forShape(Context));
730729

730+
// Compute the reduced shape term `T'.[shape]`.
731731
System.simplify(term);
732732
verify(term);
733733

734-
// Remove the 'shape' symbol from the term.
735-
assert(term.back().getKind() == Symbol::Kind::Shape);
736-
MutableTerm reducedTerm(term.begin(), term.end() - 1);
734+
// Get the term T', which is the reduced shape of T.
735+
if (term.size() != 2 ||
736+
term[0].getKind() != Symbol::Kind::GenericParam ||
737+
term[1].getKind() != Symbol::Kind::Shape) {
738+
llvm::errs() << "Invalid reduced shape\n";
739+
llvm::errs() << "Type: " << type << "\n";
740+
llvm::errs() << "Term: " << term << "\n";
741+
abort();
742+
}
737743

744+
MutableTerm reducedTerm(term.begin(), term.end() - 1);
738745
return reducedTerm;
739746
}
740747

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,7 @@ static void desugarConformanceRequirement(Requirement req,
362362
desugarRequirement(subReq, loc, result, errors);
363363
}
364364

365-
/// Desugar same-shape requirements by equating the shapes of the
366-
/// root pack types, and diagnose shape requirements on non-pack
367-
/// types.
365+
/// Diagnose shape requirements on non-pack types.
368366
static void desugarSameShapeRequirement(Requirement req, SourceLoc loc,
369367
SmallVectorImpl<Requirement> &result,
370368
SmallVectorImpl<RequirementError> &errors) {
@@ -376,8 +374,7 @@ static void desugarSameShapeRequirement(Requirement req, SourceLoc loc,
376374
}
377375

378376
result.emplace_back(RequirementKind::SameShape,
379-
req.getFirstType()->getRootGenericParam(),
380-
req.getSecondType()->getRootGenericParam());
377+
req.getFirstType(), req.getSecondType());
381378
}
382379

383380
/// Convert a requirement where the subject type might not be a type parameter,

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ RequirementMachine::initWithProtocolSignatureRequirements(
307307
///
308308
/// Returns failure if completion fails within the configured number of steps.
309309
std::pair<CompletionResult, unsigned>
310-
RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
310+
RequirementMachine::initWithGenericSignature(GenericSignature sig) {
311311
Sig = sig;
312312
Params.append(sig.getGenericParams().begin(),
313313
sig.getGenericParams().end());
@@ -323,7 +323,8 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
323323
// Collect the top-level requirements, and all transitively-referenced
324324
// protocol requirement signatures.
325325
RuleBuilder builder(Context, System.getReferencedProtocols());
326-
builder.initWithGenericSignatureRequirements(sig.getRequirements());
326+
builder.initWithGenericSignature(sig.getGenericParams(),
327+
sig.getRequirements());
327328

328329
// Add the initial set of rewrite rules to the rewrite system.
329330
System.initialize(/*recordLoops=*/false,
@@ -425,7 +426,7 @@ RequirementMachine::initWithWrittenRequirements(
425426
// Collect the top-level requirements, and all transitively-referenced
426427
// protocol requirement signatures.
427428
RuleBuilder builder(Context, System.getReferencedProtocols());
428-
builder.initWithWrittenRequirements(requirements);
429+
builder.initWithWrittenRequirements(genericParams, requirements);
429430

430431
// Add the initial set of rewrite rules to the rewrite system.
431432
System.initialize(/*recordLoops=*/true,

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class RequirementMachine final {
5454
friend class swift::AbstractGenericSignatureRequest;
5555
friend class swift::InferredGenericSignatureRequest;
5656

57-
CanGenericSignature Sig;
57+
GenericSignature Sig;
5858
SmallVector<GenericTypeParamType *, 2> Params;
5959

6060
RewriteContext &Context;
@@ -95,7 +95,7 @@ class RequirementMachine final {
9595
ArrayRef<const ProtocolDecl *> protos);
9696

9797
std::pair<CompletionResult, unsigned>
98-
initWithGenericSignature(CanGenericSignature sig);
98+
initWithGenericSignature(GenericSignature sig);
9999

100100
std::pair<CompletionResult, unsigned>
101101
initWithProtocolWrittenRequirements(

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,16 @@ void RewriteSystem::recordRewriteLoop(MutableTerm basepoint,
504504
return;
505505

506506
// Ignore the rewrite loop if it is not part of our minimization domain.
507-
if (!isInMinimizationDomain(basepoint.getRootProtocol()))
507+
//
508+
// Completion might record a rewrite loop where the basepoint is just
509+
// the term [shape]. In this case though, we know it's in our domain,
510+
// since completion only checks local rules for overlap. Other callers
511+
// of recordRewriteLoop() always pass in a valid basepoint, so we
512+
// check.
513+
if (basepoint[0].getKind() != Symbol::Kind::Shape &&
514+
!isInMinimizationDomain(basepoint.getRootProtocol())) {
508515
return;
516+
}
509517

510518
Loops.push_back(loop);
511519
}
@@ -555,11 +563,6 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
555563
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
556564
}
557565

558-
// A shape symbol must follow a generic param symbol
559-
if (symbol.getKind() == Symbol::Kind::Shape) {
560-
ASSERT_RULE(index > 0 && lhs[index - 1].getKind() == Symbol::Kind::GenericParam);
561-
}
562-
563566
if (!rule.isLHSSimplified() &&
564567
index != lhs.size() - 1) {
565568
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteConformance);
@@ -602,15 +605,10 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
602605
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Superclass);
603606
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteType);
604607

605-
if (index != lhs.size() - 1) {
608+
if (index != rhs.size() - 1) {
606609
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
607610
}
608611

609-
// A shape symbol must follow a generic param symbol
610-
if (symbol.getKind() == Symbol::Kind::Shape) {
611-
ASSERT_RULE(index > 0 && rhs[index - 1].getKind() == Symbol::Kind::GenericParam);
612-
}
613-
614612
// Completion can introduce a rule of the form
615613
//
616614
// (T.[P] => T.[concrete: C : P])
@@ -635,10 +633,15 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
635633
}
636634
}
637635

638-
auto lhsDomain = lhs.getRootProtocol();
639-
auto rhsDomain = rhs.getRootProtocol();
640-
641-
ASSERT_RULE(lhsDomain == rhsDomain);
636+
if (rhs.size() == 1 && rhs[0].getKind() == Symbol::Kind::Shape) {
637+
// We can have a rule like T.[shape] => [shape].
638+
ASSERT_RULE(lhs.back().getKind() == Symbol::Kind::Shape);
639+
} else {
640+
// Otherwise, LHS and RHS must have the same domain.
641+
auto lhsDomain = lhs.getRootProtocol();
642+
auto rhsDomain = rhs.getRootProtocol();
643+
ASSERT_RULE(lhsDomain == rhsDomain);
644+
}
642645
}
643646

644647
#undef ASSERT_RULE
@@ -709,6 +712,7 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const {
709712
loop.dump(out, *this);
710713
out << "\n";
711714
}
715+
out << "}\n";
712716
}
713717
if (!WrittenRequirements.empty()) {
714718
out << "Written requirements: {\n";

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ using namespace rewriting;
3434

3535
/// For building a rewrite system for a generic signature from canonical
3636
/// requirements.
37-
void RuleBuilder::initWithGenericSignatureRequirements(
37+
void RuleBuilder::initWithGenericSignature(
38+
ArrayRef<GenericTypeParamType *> genericParams,
3839
ArrayRef<Requirement> requirements) {
3940
assert(!Initialized);
4041
Initialized = 1;
@@ -47,6 +48,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
4748
}
4849

4950
collectRulesFromReferencedProtocols();
51+
collectPackShapeRules(genericParams);
5052

5153
// Add rewrite rules for all top-level requirements.
5254
for (const auto &req : requirements)
@@ -56,6 +58,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
5658
/// For building a rewrite system for a generic signature from user-written
5759
/// requirements.
5860
void RuleBuilder::initWithWrittenRequirements(
61+
ArrayRef<GenericTypeParamType *> genericParams,
5962
ArrayRef<StructuralRequirement> requirements) {
6063
assert(!Initialized);
6164
Initialized = 1;
@@ -68,6 +71,7 @@ void RuleBuilder::initWithWrittenRequirements(
6871
}
6972

7073
collectRulesFromReferencedProtocols();
74+
collectPackShapeRules(genericParams);
7175

7276
// Add rewrite rules for all top-level requirements.
7377
for (const auto &req : requirements)
@@ -488,3 +492,77 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
488492
localRules.end());
489493
}
490494
}
495+
496+
void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams) {
497+
if (Dump) {
498+
llvm::dbgs() << "adding shape rules\n";
499+
}
500+
501+
if (!llvm::any_of(genericParams,
502+
[](GenericTypeParamType *t) {
503+
return t->isParameterPack();
504+
})) {
505+
return;
506+
}
507+
508+
// Each non-pack generic parameter is part of the "scalar shape class", represented
509+
// by the empty term.
510+
for (auto *genericParam : genericParams) {
511+
if (genericParam->isParameterPack())
512+
continue;
513+
514+
// Add the rule (τ_d_i.[shape] => [shape]).
515+
MutableTerm lhs;
516+
lhs.add(Symbol::forGenericParam(
517+
cast<GenericTypeParamType>(genericParam->getCanonicalType()), Context));
518+
lhs.add(Symbol::forShape(Context));
519+
520+
MutableTerm rhs;
521+
rhs.add(Symbol::forShape(Context));
522+
523+
PermanentRules.emplace_back(lhs, rhs);
524+
}
525+
526+
// A member type T.[P:A] is part of the same shape class as its base type T.
527+
llvm::DenseSet<Symbol> visited;
528+
529+
auto addMemberShapeRule = [&](const ProtocolDecl *proto, AssociatedTypeDecl *assocType) {
530+
auto symbol = Symbol::forAssociatedType(proto, assocType->getName(), Context);
531+
if (!visited.insert(symbol).second)
532+
return;
533+
534+
// Add the rule ([P:A].[shape] => [shape]).
535+
MutableTerm lhs;
536+
lhs.add(symbol);
537+
lhs.add(Symbol::forShape(Context));
538+
539+
MutableTerm rhs;
540+
rhs.add(Symbol::forShape(Context));
541+
542+
// Consider it an imported rule, since it is not part of our minimization
543+
// domain. It would be more logical if we added these in the protocol component
544+
// machine for this protocol, but instead we add them in the "leaf" generic
545+
// signature machine. This avoids polluting machines that do not involve
546+
// parameter packs with these extra rules, which would otherwise just slow
547+
// things down.
548+
Rule rule(Term::get(lhs, Context), Term::get(rhs, Context));
549+
rule.markPermanent();
550+
ImportedRules.push_back(rule);
551+
};
552+
553+
for (auto *proto : ProtocolsToImport) {
554+
if (Dump) {
555+
llvm::dbgs() << "adding member shape rules for protocol " << proto->getName() << "\n";
556+
}
557+
558+
for (auto *assocType : proto->getAssociatedTypeMembers()) {
559+
addMemberShapeRule(proto, assocType);
560+
}
561+
562+
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
563+
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
564+
addMemberShapeRule(proto, assocType);
565+
}
566+
}
567+
}
568+
}

lib/AST/RequirementMachine/RuleBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,10 @@ struct RuleBuilder {
9595
Initialized = 0;
9696
}
9797

98-
void initWithGenericSignatureRequirements(ArrayRef<Requirement> requirements);
99-
void initWithWrittenRequirements(ArrayRef<StructuralRequirement> requirements);
98+
void initWithGenericSignature(ArrayRef<GenericTypeParamType *> genericParams,
99+
ArrayRef<Requirement> requirements);
100+
void initWithWrittenRequirements(ArrayRef<GenericTypeParamType *> genericParams,
101+
ArrayRef<StructuralRequirement> requirements);
100102
void initWithProtocolSignatureRequirements(ArrayRef<const ProtocolDecl *> proto);
101103
void initWithProtocolWrittenRequirements(
102104
ArrayRef<const ProtocolDecl *> component,
@@ -106,6 +108,7 @@ struct RuleBuilder {
106108
ArrayRef<Term> substitutions);
107109
void addReferencedProtocol(const ProtocolDecl *proto);
108110
void collectRulesFromReferencedProtocols();
111+
void collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams);
109112

110113
private:
111114
void addPermanentProtocolRules(const ProtocolDecl *proto);

lib/AST/RequirementMachine/Symbol.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,12 +700,7 @@ void Symbol::dump(llvm::raw_ostream &out) const {
700700
}
701701

702702
case Kind::GenericParam: {
703-
auto *gp = getGenericParam();
704-
if (gp->isParameterPack()) {
705-
out << "(" << Type(gp) << "…)";
706-
} else {
707-
out << Type(gp);
708-
}
703+
out << Type(getGenericParam());
709704
return;
710705
}
711706

test/Generics/pack-shape-requirements.swift

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -disable-availability-checking 2>&1 | %FileCheck %s
22

33
protocol P {
4-
associatedtype A
4+
associatedtype A: P
55
}
66

77
// CHECK-LABEL: inferSameShape(ts:us:)
@@ -65,3 +65,21 @@ struct Ts<each T> {
6565
func expandedParameters<each T, each Result>(_ t: repeat each T, transform: repeat (each T) -> each Result) -> (repeat each Result) {
6666
fatalError()
6767
}
68+
69+
70+
//////
71+
///
72+
/// Same-type requirements should imply same-shape requirements.
73+
///
74+
//////
75+
76+
// CHECK-LABEL: sameType1
77+
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : P, repeat each U : P, repeat (each T).[P]A == (each U).[P]A>
78+
func sameType1<each T, each U>(_: repeat (each T, each U)) where repeat each T: P, repeat each U: P, repeat each T.A == each U.A {}
79+
80+
// Make sure inherited associated types are handled
81+
protocol Q: P where A: Q {}
82+
83+
// CHECK-LABEL: sameType2
84+
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : Q, repeat each U : Q, repeat (each T).[P]A.[P]A == (each U).[P]A.[P]A>
85+
func sameType2<each T, each U>(_: repeat (each T, each U)) where repeat each T: Q, repeat each U: Q, repeat each T.A.A == each U.A.A {}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %target-swift-frontend -emit-ir %s
2+
3+
public protocol P {}
4+
5+
public protocol Q {
6+
associatedtype A: P
7+
}
8+
9+
public func f<T: P>(_: T) {}
10+
11+
public func foo1<each T: Q, each U>(t: repeat each T, u: repeat each U)
12+
where repeat (each U) == (each T).A {
13+
repeat f(each u)
14+
}
15+
16+
public func foo2<each T: Q>(t: repeat each T, u: repeat each T.A) {
17+
repeat f(each u)
18+
}

0 commit comments

Comments
 (0)