Skip to content

Commit d3702ba

Browse files
authored
Merge pull request #41664 from hborla/redundant-requirements
2 parents 07bb2b1 + 7ce6504 commit d3702ba

18 files changed

+397
-60
lines changed

lib/AST/RequirementMachine/ConcreteTypeWitness.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,11 @@ void PropertyMap::inferConditionalRequirements(
578578
for (const auto &rule : builder.PermanentRules)
579579
System.addPermanentRule(rule.first, rule.second);
580580

581-
for (const auto &rule : builder.RequirementRules)
582-
System.addExplicitRule(rule.first, rule.second);
581+
for (const auto &rule : builder.RequirementRules) {
582+
auto lhs = std::get<0>(rule);
583+
auto rhs = std::get<1>(rule);
584+
System.addExplicitRule(lhs, rhs, /*requirementID=*/None);
585+
}
583586
}
584587
}
585588

lib/AST/RequirementMachine/Debug.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ enum class DebugFlags : unsigned {
6464
RedundantRulesDetail = (1<<13),
6565

6666
/// Print debug output from the concrete contraction pre-processing pass.
67-
ConcreteContraction = (1<<14)
67+
ConcreteContraction = (1<<14),
68+
69+
/// Print debug output from propagating explicit requirement
70+
/// IDs from redundant rules.
71+
PropagateRequirementIDs = (1<<15),
6872
};
6973

7074
using DebugOptions = OptionSet<DebugFlags>;

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,14 @@ void RewriteLoop::recompute(const RewriteSystem &system) {
7474

7575
ProjectionCount = 0;
7676
DecomposeCount = 0;
77-
78-
// Rules appearing in empty context (possibly more than once).
79-
llvm::SmallDenseSet<unsigned, 2> rulesInEmptyContext;
80-
81-
// The number of times each rule appears (with or without context).
82-
llvm::SmallDenseMap<unsigned, unsigned, 2> ruleMultiplicity;
77+
Useful = false;
8378

8479
RewritePathEvaluator evaluator(Basepoint);
85-
8680
for (auto step : Path) {
8781
switch (step.Kind) {
88-
case RewriteStep::Rule: {
89-
if (!step.isInContext() && !evaluator.isInContext())
90-
rulesInEmptyContext.insert(step.getRuleID());
91-
92-
++ruleMultiplicity[step.getRuleID()];
82+
case RewriteStep::Rule:
83+
Useful |= (!step.isInContext() && !evaluator.isInContext());
9384
break;
94-
}
9585

9686
case RewriteStep::LeftConcreteProjection:
9787
++ProjectionCount;
@@ -112,18 +102,7 @@ void RewriteLoop::recompute(const RewriteSystem &system) {
112102
evaluator.apply(step, system);
113103
}
114104

115-
Useful = !rulesInEmptyContext.empty();
116-
117-
RulesInEmptyContext.clear();
118-
119-
// Collect all rules that we saw exactly once in empty context.
120-
for (auto rule : rulesInEmptyContext) {
121-
auto found = ruleMultiplicity.find(rule);
122-
assert(found != ruleMultiplicity.end());
123-
124-
if (found->second == 1)
125-
RulesInEmptyContext.push_back(rule);
126-
}
105+
RulesInEmptyContext = Path.getRulesInEmptyContext(Basepoint, system);
127106
}
128107

129108
/// A rewrite rule is redundant if it appears exactly once in a loop
@@ -211,6 +190,46 @@ void RewriteSystem::propagateExplicitBits() {
211190
}
212191
}
213192

193+
/// Propagate requirement IDs from redundant rules to their
194+
/// replacements that appear once in empty context.
195+
void RewriteSystem::propagateRedundantRequirementIDs() {
196+
if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
197+
llvm::dbgs() << "\nPropagating requirement IDs: {";
198+
}
199+
200+
for (auto ruleAndReplacement : RedundantRules) {
201+
auto ruleID = ruleAndReplacement.first;
202+
auto rewritePath = ruleAndReplacement.second;
203+
auto &rule = Rules[ruleID];
204+
205+
auto requirementID = rule.getRequirementID();
206+
if (!requirementID.hasValue())
207+
continue;
208+
209+
MutableTerm lhs(rule.getLHS());
210+
for (auto ruleID : rewritePath.getRulesInEmptyContext(lhs, *this)) {
211+
auto &replacement = Rules[ruleID];
212+
if (!replacement.isPermanent() &&
213+
!replacement.getRequirementID().hasValue()) {
214+
if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
215+
llvm::dbgs() << "\n- propagating ID = "
216+
<< requirementID
217+
<< "\n from ";
218+
rule.dump(llvm::dbgs());
219+
llvm::dbgs() << "\n to ";
220+
replacement.dump(llvm::dbgs());
221+
}
222+
223+
replacement.setRequirementID(requirementID);
224+
}
225+
}
226+
}
227+
228+
if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
229+
llvm::dbgs() << "\n}\n";
230+
}
231+
}
232+
214233
/// After propagating the 'explicit' bit on rules, process pairs of
215234
/// conflicting rules, marking one or both of the rules as conflicting,
216235
/// which instructs minimization to drop them.
@@ -409,6 +428,51 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID,
409428
return true;
410429
}
411430

431+
SmallVector<unsigned, 1>
432+
RewritePath::getRulesInEmptyContext(const MutableTerm &term,
433+
const RewriteSystem &system) {
434+
// Rules appearing in empty context (possibly more than once).
435+
llvm::SmallDenseSet<unsigned, 2> rulesInEmptyContext;
436+
// The number of times each rule appears (with or without context).
437+
llvm::SmallDenseMap<unsigned, unsigned, 2> ruleFrequency;
438+
439+
RewritePathEvaluator evaluator(term);
440+
for (auto step : Steps) {
441+
switch (step.Kind) {
442+
case RewriteStep::Rule: {
443+
if (!step.isInContext() && !evaluator.isInContext())
444+
rulesInEmptyContext.insert(step.getRuleID());
445+
446+
++ruleFrequency[step.getRuleID()];
447+
break;
448+
}
449+
450+
case RewriteStep::LeftConcreteProjection:
451+
case RewriteStep::Decompose:
452+
case RewriteStep::PrefixSubstitutions:
453+
case RewriteStep::Shift:
454+
case RewriteStep::Relation:
455+
case RewriteStep::DecomposeConcrete:
456+
case RewriteStep::RightConcreteProjection:
457+
break;
458+
}
459+
460+
evaluator.apply(step, system);
461+
}
462+
463+
// Collect all rules that we saw exactly once in empty context.
464+
SmallVector<unsigned, 1> rulesOnceInEmptyContext;
465+
for (auto rule : rulesInEmptyContext) {
466+
auto found = ruleFrequency.find(rule);
467+
assert(found != ruleFrequency.end());
468+
469+
if (found->second == 1)
470+
rulesOnceInEmptyContext.push_back(rule);
471+
}
472+
473+
return rulesOnceInEmptyContext;
474+
}
475+
412476
/// Find a rule to delete by looking through all loops for rewrite rules appearing
413477
/// once in empty context. Returns a pair consisting of a loop ID and a rule ID,
414478
/// otherwise returns None.
@@ -654,7 +718,7 @@ void RewriteSystem::performHomotopyReduction(
654718

655719
// If no redundant rules remain which can be eliminated by this pass, stop.
656720
if (!optPair)
657-
return;
721+
break;
658722

659723
unsigned loopID = optPair->first;
660724
unsigned ruleID = optPair->second;
@@ -679,6 +743,8 @@ void RewriteSystem::performHomotopyReduction(
679743

680744
deleteRule(ruleID, replacementPath);
681745
}
746+
747+
propagateRedundantRequirementIDs();
682748
}
683749

684750
void RewriteSystem::normalizeRedundantRules() {

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ static void desugarSameTypeRequirement(Type lhs, Type rhs, SourceLoc loc,
7777
Type sugaredFirstType) {
7878
if (firstType->isTypeParameter() && secondType->isTypeParameter()) {
7979
result.emplace_back(RequirementKind::SameType,
80-
firstType, secondType);
80+
sugaredFirstType, secondType);
8181
recordedRequirements = true;
8282
return true;
8383
}
8484

8585
if (firstType->isTypeParameter()) {
8686
result.emplace_back(RequirementKind::SameType,
87-
firstType, secondType);
87+
sugaredFirstType, secondType);
8888
recordedRequirements = true;
8989
return true;
9090
}
@@ -530,7 +530,7 @@ void swift::rewriting::realizeInheritedRequirements(
530530
/// \returns true if any errors were emitted, and false otherwise (including
531531
/// when only warnings were emitted).
532532
bool swift::rewriting::diagnoseRequirementErrors(
533-
ASTContext &ctx, SmallVectorImpl<RequirementError> &errors,
533+
ASTContext &ctx, ArrayRef<RequirementError> errors,
534534
bool allowConcreteGenericParams) {
535535
bool diagnosedError = false;
536536

@@ -1034,7 +1034,7 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
10341034

10351035
// Add rewrite rules for all top-level requirements.
10361036
for (const auto &req : requirements)
1037-
addRequirement(req, /*proto=*/nullptr);
1037+
addRequirement(req, /*proto=*/nullptr, /*requirementID=*/None);
10381038
}
10391039

10401040
void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements) {
@@ -1184,22 +1184,26 @@ swift::rewriting::getRuleForRequirement(const Requirement &req,
11841184
}
11851185

11861186
void RuleBuilder::addRequirement(const Requirement &req,
1187-
const ProtocolDecl *proto) {
1187+
const ProtocolDecl *proto,
1188+
Optional<unsigned> requirementID) {
11881189
if (Dump) {
11891190
llvm::dbgs() << "+ ";
11901191
req.dump(llvm::dbgs());
11911192
llvm::dbgs() << "\n";
11921193
}
11931194

1194-
RequirementRules.push_back(
1195+
auto rule =
11951196
getRuleForRequirement(req, proto, /*substitutions=*/None,
1196-
Context));
1197+
Context);
1198+
RequirementRules.push_back(
1199+
std::make_tuple(rule.first, rule.second, requirementID));
11971200
}
11981201

11991202
void RuleBuilder::addRequirement(const StructuralRequirement &req,
12001203
const ProtocolDecl *proto) {
1201-
// FIXME: Preserve source location information for diagnostics.
1202-
addRequirement(req.req.getCanonical(), proto);
1204+
WrittenRequirements.push_back(req);
1205+
unsigned requirementID = WrittenRequirements.size() - 1;
1206+
addRequirement(req.req.getCanonical(), proto, requirementID);
12031207
}
12041208

12051209
/// Lowers a protocol typealias to a rewrite rule.
@@ -1231,7 +1235,8 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
12311235
constraintTerm.add(Symbol::forConcreteType(concreteType, result, Context));
12321236
}
12331237

1234-
RequirementRules.emplace_back(subjectTerm, constraintTerm);
1238+
RequirementRules.emplace_back(subjectTerm, constraintTerm,
1239+
/*requirementID=*/None);
12351240
}
12361241

12371242
/// Record information about a protocol if we have no seen it yet.
@@ -1292,11 +1297,11 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
12921297
addRequirement(req, proto);
12931298

12941299
for (auto req : proto->getTypeAliasRequirements())
1295-
addRequirement(req.getCanonical(), proto);
1300+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
12961301
} else {
12971302
auto reqs = proto->getRequirementSignature();
12981303
for (auto req : reqs.getRequirements())
1299-
addRequirement(req.getCanonical(), proto);
1304+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
13001305
for (auto alias : reqs.getTypeAliases())
13011306
addTypeAlias(alias, proto);
13021307
}

lib/AST/RequirementMachine/RequirementLowering.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void realizeInheritedRequirements(TypeDecl *decl, Type type,
5959
SmallVectorImpl<RequirementError> &errors);
6060

6161
bool diagnoseRequirementErrors(ASTContext &ctx,
62-
SmallVectorImpl<RequirementError> &errors,
62+
ArrayRef<RequirementError> errors,
6363
bool allowConcreteGenericParams);
6464

6565
std::pair<MutableTerm, MutableTerm>
@@ -104,7 +104,12 @@ struct RuleBuilder {
104104

105105
/// New rules derived from requirements written by the user, which can be
106106
/// eliminated by homotopy reduction.
107-
std::vector<std::pair<MutableTerm, MutableTerm>> RequirementRules;
107+
std::vector<std::tuple<MutableTerm, MutableTerm, Optional<unsigned>>>
108+
RequirementRules;
109+
110+
/// Requirements written in source code. The requirement ID in the above
111+
/// \c RequirementRules vector is an index into this array.
112+
std::vector<StructuralRequirement> WrittenRequirements;
108113

109114
/// Enables debugging output. Controlled by the -dump-requirement-machine
110115
/// frontend flag.
@@ -123,7 +128,8 @@ struct RuleBuilder {
123128
void addAssociatedType(const AssociatedTypeDecl *type,
124129
const ProtocolDecl *proto);
125130
void addRequirement(const Requirement &req,
126-
const ProtocolDecl *proto);
131+
const ProtocolDecl *proto,
132+
Optional<unsigned> requirementID);
127133
void addRequirement(const StructuralRequirement &req,
128134
const ProtocolDecl *proto);
129135
void addTypeAlias(const ProtocolTypeAlias &alias,

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
9696
// Add the initial set of rewrite rules to the rewrite system.
9797
System.initialize(/*recordLoops=*/false,
9898
/*protos=*/ArrayRef<const ProtocolDecl *>(),
99+
std::move(builder.WrittenRequirements),
99100
std::move(builder.PermanentRules),
100101
std::move(builder.RequirementRules));
101102

@@ -138,6 +139,7 @@ RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
138139

139140
// Add the initial set of rewrite rules to the rewrite system.
140141
System.initialize(/*recordLoops=*/true, protos,
142+
std::move(builder.WrittenRequirements),
141143
std::move(builder.PermanentRules),
142144
std::move(builder.RequirementRules));
143145

@@ -185,6 +187,7 @@ RequirementMachine::initWithWrittenRequirements(
185187
// Add the initial set of rewrite rules to the rewrite system.
186188
System.initialize(/*recordLoops=*/true,
187189
/*protos=*/ArrayRef<const ProtocolDecl *>(),
190+
std::move(builder.WrittenRequirements),
188191
std::move(builder.PermanentRules),
189192
std::move(builder.RequirementRules));
190193

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
352352
std::unique_ptr<RequirementMachine> machine(new RequirementMachine(
353353
ctx.getRewriteContext()));
354354

355+
SmallVector<RequirementError, 4> errors;
356+
355357
auto status = machine->initWithProtocols(component);
356358
if (status.first != CompletionResult::Success) {
357359
// All we can do at this point is diagnose and give each protocol an empty
@@ -419,6 +421,13 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
419421
}
420422
}
421423

424+
if (ctx.LangOpts.RequirementMachineProtocolSignatures ==
425+
RequirementMachineMode::Enabled) {
426+
machine->System.computeRedundantRequirementDiagnostics(errors);
427+
diagnoseRequirementErrors(ctx, errors,
428+
/*allowConcreteGenericParams=*/false);
429+
}
430+
422431
// Return the result for the specific protocol this request was kicked off on.
423432
return *result;
424433
}
@@ -667,7 +676,9 @@ InferredGenericSignatureRequestRQM::evaluate(
667676

668677
if (ctx.LangOpts.RequirementMachineInferredSignatures ==
669678
RequirementMachineMode::Enabled) {
670-
hadError |= diagnoseRequirementErrors(ctx, errors, allowConcreteGenericParams);
679+
machine->System.computeRedundantRequirementDiagnostics(errors);
680+
hadError |= diagnoseRequirementErrors(ctx, errors,
681+
allowConcreteGenericParams);
671682
}
672683

673684
// FIXME: Handle allowConcreteGenericParams

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ static DebugOptions parseDebugFlags(StringRef debugFlags) {
4242
.Case("redundant-rules", DebugFlags::RedundantRules)
4343
.Case("redundant-rules-detail", DebugFlags::RedundantRulesDetail)
4444
.Case("concrete-contraction", DebugFlags::ConcreteContraction)
45+
.Case("propagate-requirement-ids", DebugFlags::PropagateRequirementIDs)
4546
.Default(None);
4647
if (!flag) {
4748
llvm::errs() << "Unknown debug flag in -debug-requirement-machine "

lib/AST/RequirementMachine/RewriteLoop.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,9 @@ class RewritePath {
409409

410410
bool replaceRuleWithPath(unsigned ruleID, const RewritePath &path);
411411

412+
SmallVector<unsigned, 1> getRulesInEmptyContext(const MutableTerm &term,
413+
const RewriteSystem &system);
414+
412415
void invert();
413416

414417
bool computeFreelyReducedForm();

0 commit comments

Comments
 (0)