Skip to content

RequirementMachine: Implement -requirement-machine-protocol-signatures=verify #40011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4168,6 +4168,7 @@ class ProtocolDecl final : public NominalTypeDecl {
friend class StructuralRequirementsRequest;
friend class ProtocolDependenciesRequest;
friend class RequirementSignatureRequest;
friend class RequirementSignatureRequestRQM;
friend class ProtocolRequiresClassRequest;
friend class ExistentialConformsToSelfRequest;
friend class InheritedProtocolsRequest;
Expand Down
20 changes: 20 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,26 @@ class ProtocolDependenciesRequest :
bool isCached() const { return true; }
};

/// Compute the requirements that describe a protocol using the
/// RequirementMachine.
class RequirementSignatureRequestRQM :
public SimpleRequest<RequirementSignatureRequestRQM,
ArrayRef<Requirement>(ProtocolDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
ArrayRef<Requirement>
evaluate(Evaluator &evaluator, ProtocolDecl *proto) const;

public:
bool isCached() const { return true; }
};

/// Compute the requirements that describe a protocol.
class RequirementSignatureRequest :
public SimpleRequest<RequirementSignatureRequest,
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ SWIFT_REQUEST(TypeChecker, StructuralRequirementsRequest,
SWIFT_REQUEST(TypeChecker, ProtocolDependenciesRequest,
ArrayRef<ProtocolDecl *>(ProtocolDecl *), Cached,
HasNearestLocation)
SWIFT_REQUEST(TypeChecker, RequirementSignatureRequestRQM,
ArrayRef<Requirement>(ProtocolDecl *), Cached,
NoLocationInfo)
SWIFT_REQUEST(TypeChecker, RequirementSignatureRequest,
ArrayRef<Requirement>(ProtocolDecl *), SeparatelyCached,
NoLocationInfo)
Expand Down
97 changes: 95 additions & 2 deletions lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@
using namespace swift;
using llvm::DenseMap;

/// Define this to 1 to enable expensive assertions.
#define SWIFT_GSB_EXPENSIVE_ASSERTIONS 0
#define DEBUG_TYPE "Serialization"

STATISTIC(NumLazyRequirementSignaturesLoaded,
"# of lazily-deserialized requirement signatures loaded");

#undef DEBUG_TYPE

namespace {
typedef GenericSignatureBuilder::RequirementSource RequirementSource;
Expand Down Expand Up @@ -8702,3 +8706,92 @@ InferredGenericSignatureRequest::evaluate(
allowConcreteGenericParams);
return GenericSignatureWithError(result, hadError);
}

ArrayRef<Requirement>
RequirementSignatureRequest::evaluate(Evaluator &evaluator,
ProtocolDecl *proto) const {
ASTContext &ctx = proto->getASTContext();

// First check if we have a deserializable requirement signature.
if (proto->hasLazyRequirementSignature()) {
++NumLazyRequirementSignaturesLoaded;
// FIXME: (transitional) increment the redundant "always-on" counter.
if (ctx.Stats)
++ctx.Stats->getFrontendCounters().NumLazyRequirementSignaturesLoaded;

auto contextData = static_cast<LazyProtocolData *>(
ctx.getOrCreateLazyContextData(proto, nullptr));

SmallVector<Requirement, 8> requirements;
contextData->loader->loadRequirementSignature(
proto, contextData->requirementSignatureData, requirements);
if (requirements.empty())
return None;
return ctx.AllocateCopy(requirements);
}

auto buildViaGSB = [&]() {
GenericSignatureBuilder builder(proto->getASTContext());

// Add all of the generic parameters.
for (auto gp : *proto->getGenericParams())
builder.addGenericParameter(gp);

// Add the conformance of 'self' to the protocol.
auto selfType =
proto->getSelfInterfaceType()->castTo<GenericTypeParamType>();
auto requirement =
Requirement(RequirementKind::Conformance, selfType,
proto->getDeclaredInterfaceType());

builder.addRequirement(
requirement,
GenericSignatureBuilder::RequirementSource::forRequirementSignature(
builder, selfType, proto),
nullptr);

auto reqSignature = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/false,
/*requirementSignatureSelfProto=*/proto);
return reqSignature.getRequirements();
};

auto buildViaRQM = [&]() {
return evaluateOrDefault(
ctx.evaluator,
RequirementSignatureRequestRQM{const_cast<ProtocolDecl *>(proto)},
ArrayRef<Requirement>());
};

switch (ctx.LangOpts.RequirementMachineProtocolSignatures) {
case RequirementMachineMode::Disabled:
return buildViaGSB();

case RequirementMachineMode::Enabled:
return buildViaRQM();

case RequirementMachineMode::Verify: {
auto rqmResult = buildViaRQM();
auto gsbResult = buildViaGSB();

if (rqmResult.size() != gsbResult.size() ||
!std::equal(rqmResult.begin(), rqmResult.end(),
gsbResult.begin())) {
llvm::errs() << "RequirementMachine protocol signature minimization is broken:\n";
llvm::errs() << "Protocol: " << proto->getName() << "\n";

auto rqmSig = GenericSignature::get(
proto->getGenericSignature().getGenericParams(), rqmResult);
llvm::errs() << "RequirementMachine says: " << rqmSig << "\n";

auto gsbSig = GenericSignature::get(
proto->getGenericSignature().getGenericParams(), gsbResult);
llvm::errs() << "GenericSignatureBuilder says: " << gsbSig << "\n";

abort();
}

return gsbResult;
}
}
}
23 changes: 20 additions & 3 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID,

SmallVector<RewriteStep, 4> newSteps;

// Keep track of Decompose/Compose pairs. Any rewrite steps in
// between do not need to be re-contextualized, since they
// operate on new terms that were pushed on the stack by the
// Compose operation.
unsigned decomposeCount = 0;

for (const auto &step : Steps) {
switch (step.Kind) {
case RewriteStep::ApplyRewriteRule: {
Expand All @@ -202,13 +208,24 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID,
}

auto adjustStep = [&](RewriteStep newStep) {
newStep.StartOffset += step.StartOffset;
bool inverse = newStep.Inverse ^ step.Inverse;

if (newStep.Kind == RewriteStep::Decompose && inverse) {
assert(decomposeCount > 0);
--decomposeCount;
}

if (newStep.Kind == RewriteStep::ApplyRewriteRule)
if (decomposeCount == 0) {
newStep.StartOffset += step.StartOffset;
newStep.EndOffset += step.EndOffset;
}

newStep.Inverse ^= step.Inverse;
newStep.Inverse = inverse;
newSteps.push_back(newStep);

if (newStep.Kind == RewriteStep::Decompose && !inverse) {
++decomposeCount;
}
};

if (step.Inverse) {
Expand Down
44 changes: 22 additions & 22 deletions lib/AST/RequirementMachine/PropertyMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,6 @@ void PropertyMap::addProperty(
props->addProperty(property, Context,
inducedRules, Debug.contains(DebugFlags::ConcreteUnification));
}
void PropertyMap::dump(llvm::raw_ostream &out) const {
out << "Property map: {\n";
for (const auto &props : Entries) {
out << " ";
props->dump(out);
out << "\n";
}
out << "}\n";
}

/// Build the property map from all rules of the form T.[p] => T, where
/// [p] is a property symbol.
Expand All @@ -318,19 +309,18 @@ void PropertyMap::dump(llvm::raw_ostream &out) const {
/// left hand side has a length exceeding \p maxDepth.
///
/// Otherwise, the status is CompletionResult::Success.
std::pair<RewriteSystem::CompletionResult, unsigned>
RewriteSystem::buildPropertyMap(PropertyMap &map,
unsigned maxIterations,
unsigned maxDepth) {
map.clear();
std::pair<CompletionResult, unsigned>
PropertyMap::buildPropertyMap(unsigned maxIterations,
unsigned maxDepth) {
clear();

// PropertyMap::addRule() requires that shorter rules are added
// before longer rules, so that it can perform lookups on suffixes and call
// PropertyBag::copyPropertiesFrom(). However, we don't have to perform a
// full sort by term order here; a bucket sort by term length suffices.
SmallVector<std::vector<std::pair<Term, Symbol>>, 4> properties;

for (const auto &rule : Rules) {
for (const auto &rule : System.getRules()) {
if (rule.isSimplified())
continue;

Expand All @@ -355,34 +345,44 @@ RewriteSystem::buildPropertyMap(PropertyMap &map,

for (const auto &bucket : properties) {
for (auto pair : bucket) {
map.addProperty(pair.first, pair.second, inducedRules);
addProperty(pair.first, pair.second, inducedRules);
}
}

// We collect terms with fully concrete types so that we can re-use them
// to tie off recursion in the next step.
map.computeConcreteTypeInDomainMap();
computeConcreteTypeInDomainMap();

// Now, we merge concrete type rules with conformance rules, by adding
// relations between associated type members of type parameters with
// the concrete type witnesses in the concrete type's conformance.
map.concretizeNestedTypesFromConcreteParents(inducedRules);
concretizeNestedTypesFromConcreteParents(inducedRules);

// Some of the induced rules might be trivial; only count the induced rules
// where the left hand side is not already equivalent to the right hand side.
unsigned addedNewRules = 0;
for (auto pair : inducedRules) {
if (addRule(pair.first, pair.second)) {
if (System.addRule(pair.first, pair.second)) {
++addedNewRules;

const auto &newRule = Rules.back();
if (newRule.getLHS().size() > maxDepth)
const auto &newRule = System.getRules().back();
if (newRule.getDepth() > maxDepth)
return std::make_pair(CompletionResult::MaxDepth, addedNewRules);
}
}

if (Rules.size() > maxIterations)
if (System.getRules().size() > maxIterations)
return std::make_pair(CompletionResult::MaxIterations, addedNewRules);

return std::make_pair(CompletionResult::Success, addedNewRules);
}

void PropertyMap::dump(llvm::raw_ostream &out) const {
out << "Property map: {\n";
for (const auto &props : Entries) {
out << " ";
props->dump(out);
out << "\n";
}
out << "}\n";
}
6 changes: 5 additions & 1 deletion lib/AST/RequirementMachine/PropertyMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,13 @@ class PropertyMap {

PropertyBag *lookUpProperties(const MutableTerm &key) const;

std::pair<CompletionResult, unsigned>
buildPropertyMap(unsigned maxIterations,
unsigned maxDepth);

void dump(llvm::raw_ostream &out) const;

private:
void clear();
void addProperty(Term key, Symbol property,
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules);
Expand All @@ -173,7 +178,6 @@ class PropertyMap {
void concretizeNestedTypesFromConcreteParents(
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) const;

private:
void concretizeNestedTypesFromConcreteParent(
Term key, RequirementKind requirementKind,
CanType concreteType, ArrayRef<Term> substitutions,
Expand Down
9 changes: 4 additions & 5 deletions lib/AST/RequirementMachine/RequirementMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,16 +528,16 @@ void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy)
// Check for failure.
auto checkCompletionResult = [&]() {
switch (result.first) {
case RewriteSystem::CompletionResult::Success:
case CompletionResult::Success:
break;

case RewriteSystem::CompletionResult::MaxIterations:
case CompletionResult::MaxIterations:
llvm::errs() << "Generic signature " << Sig
<< " exceeds maximum completion step count\n";
System.dump(llvm::errs());
abort();

case RewriteSystem::CompletionResult::MaxDepth:
case CompletionResult::MaxDepth:
llvm::errs() << "Generic signature " << Sig
<< " exceeds maximum completion depth\n";
System.dump(llvm::errs());
Expand All @@ -553,8 +553,7 @@ void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy)
// Build the property map, which also performs concrete term
// unification; if this added any new rules, run the completion
// procedure again.
result = System.buildPropertyMap(
Map,
result = Map.buildPropertyMap(
RequirementMachineStepLimit,
RequirementMachineDepthLimit);

Expand Down
Loading