Skip to content

RequirementMachine: Same-type requirements imply same-shape requirements #67064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions lib/AST/RequirementMachine/GenericSignatureQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -721,20 +721,27 @@ MutableTerm
RequirementMachine::getReducedShapeTerm(Type type) const {
assert(type->isParameterPack());

auto rootType = type->getRootGenericParam();
auto term = Context.getMutableTermForType(rootType->getCanonicalType(),
auto term = Context.getMutableTermForType(type->getCanonicalType(),
/*proto=*/nullptr);

// Append the 'shape' symbol to the term.
// From a type term T, form the shape term `T.[shape]`.
term.add(Symbol::forShape(Context));

// Compute the reduced shape term `T'.[shape]`.
System.simplify(term);
verify(term);

// Remove the 'shape' symbol from the term.
assert(term.back().getKind() == Symbol::Kind::Shape);
MutableTerm reducedTerm(term.begin(), term.end() - 1);
// Get the term T', which is the reduced shape of T.
if (term.size() != 2 ||
term[0].getKind() != Symbol::Kind::GenericParam ||
term[1].getKind() != Symbol::Kind::Shape) {
llvm::errs() << "Invalid reduced shape\n";
llvm::errs() << "Type: " << type << "\n";
llvm::errs() << "Term: " << term << "\n";
abort();
}

MutableTerm reducedTerm(term.begin(), term.end() - 1);
return reducedTerm;
}

Expand Down
7 changes: 2 additions & 5 deletions lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,7 @@ static void desugarConformanceRequirement(Requirement req,
desugarRequirement(subReq, loc, result, errors);
}

/// Desugar same-shape requirements by equating the shapes of the
/// root pack types, and diagnose shape requirements on non-pack
/// types.
/// Diagnose shape requirements on non-pack types.
static void desugarSameShapeRequirement(Requirement req, SourceLoc loc,
SmallVectorImpl<Requirement> &result,
SmallVectorImpl<RequirementError> &errors) {
Expand All @@ -376,8 +374,7 @@ static void desugarSameShapeRequirement(Requirement req, SourceLoc loc,
}

result.emplace_back(RequirementKind::SameShape,
req.getFirstType()->getRootGenericParam(),
req.getSecondType()->getRootGenericParam());
req.getFirstType(), req.getSecondType());
}

/// Convert a requirement where the subject type might not be a type parameter,
Expand Down
7 changes: 4 additions & 3 deletions lib/AST/RequirementMachine/RequirementMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ RequirementMachine::initWithProtocolSignatureRequirements(
///
/// Returns failure if completion fails within the configured number of steps.
std::pair<CompletionResult, unsigned>
RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
RequirementMachine::initWithGenericSignature(GenericSignature sig) {
Sig = sig;
Params.append(sig.getGenericParams().begin(),
sig.getGenericParams().end());
Expand All @@ -323,7 +323,8 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
// Collect the top-level requirements, and all transitively-referenced
// protocol requirement signatures.
RuleBuilder builder(Context, System.getReferencedProtocols());
builder.initWithGenericSignatureRequirements(sig.getRequirements());
builder.initWithGenericSignature(sig.getGenericParams(),
sig.getRequirements());

// Add the initial set of rewrite rules to the rewrite system.
System.initialize(/*recordLoops=*/false,
Expand Down Expand Up @@ -425,7 +426,7 @@ RequirementMachine::initWithWrittenRequirements(
// Collect the top-level requirements, and all transitively-referenced
// protocol requirement signatures.
RuleBuilder builder(Context, System.getReferencedProtocols());
builder.initWithWrittenRequirements(requirements);
builder.initWithWrittenRequirements(genericParams, requirements);

// Add the initial set of rewrite rules to the rewrite system.
System.initialize(/*recordLoops=*/true,
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/RequirementMachine/RequirementMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class RequirementMachine final {
friend class swift::AbstractGenericSignatureRequest;
friend class swift::InferredGenericSignatureRequest;

CanGenericSignature Sig;
GenericSignature Sig;
SmallVector<GenericTypeParamType *, 2> Params;

RewriteContext &Context;
Expand Down Expand Up @@ -95,7 +95,7 @@ class RequirementMachine final {
ArrayRef<const ProtocolDecl *> protos);

std::pair<CompletionResult, unsigned>
initWithGenericSignature(CanGenericSignature sig);
initWithGenericSignature(GenericSignature sig);

std::pair<CompletionResult, unsigned>
initWithProtocolWrittenRequirements(
Expand Down
36 changes: 20 additions & 16 deletions lib/AST/RequirementMachine/RewriteSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,16 @@ void RewriteSystem::recordRewriteLoop(MutableTerm basepoint,
return;

// Ignore the rewrite loop if it is not part of our minimization domain.
if (!isInMinimizationDomain(basepoint.getRootProtocol()))
//
// Completion might record a rewrite loop where the basepoint is just
// the term [shape]. In this case though, we know it's in our domain,
// since completion only checks local rules for overlap. Other callers
// of recordRewriteLoop() always pass in a valid basepoint, so we
// check.
if (basepoint[0].getKind() != Symbol::Kind::Shape &&
!isInMinimizationDomain(basepoint.getRootProtocol())) {
return;
}

Loops.push_back(loop);
}
Expand Down Expand Up @@ -555,11 +563,6 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
}

// A shape symbol must follow a generic param symbol
if (symbol.getKind() == Symbol::Kind::Shape) {
ASSERT_RULE(index > 0 && lhs[index - 1].getKind() == Symbol::Kind::GenericParam);
}

if (!rule.isLHSSimplified() &&
index != lhs.size() - 1) {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteConformance);
Expand Down Expand Up @@ -602,15 +605,10 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Superclass);
ASSERT_RULE(symbol.getKind() != Symbol::Kind::ConcreteType);

if (index != lhs.size() - 1) {
if (index != rhs.size() - 1) {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Shape);
}

// A shape symbol must follow a generic param symbol
if (symbol.getKind() == Symbol::Kind::Shape) {
ASSERT_RULE(index > 0 && rhs[index - 1].getKind() == Symbol::Kind::GenericParam);
}

// Completion can introduce a rule of the form
//
// (T.[P] => T.[concrete: C : P])
Expand All @@ -635,10 +633,15 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
}
}

auto lhsDomain = lhs.getRootProtocol();
auto rhsDomain = rhs.getRootProtocol();

ASSERT_RULE(lhsDomain == rhsDomain);
if (rhs.size() == 1 && rhs[0].getKind() == Symbol::Kind::Shape) {
// We can have a rule like T.[shape] => [shape].
ASSERT_RULE(lhs.back().getKind() == Symbol::Kind::Shape);
} else {
// Otherwise, LHS and RHS must have the same domain.
auto lhsDomain = lhs.getRootProtocol();
auto rhsDomain = rhs.getRootProtocol();
ASSERT_RULE(lhsDomain == rhsDomain);
}
}

#undef ASSERT_RULE
Expand Down Expand Up @@ -709,6 +712,7 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const {
loop.dump(out, *this);
out << "\n";
}
out << "}\n";
}
if (!WrittenRequirements.empty()) {
out << "Written requirements: {\n";
Expand Down
80 changes: 79 additions & 1 deletion lib/AST/RequirementMachine/RuleBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ using namespace rewriting;

/// For building a rewrite system for a generic signature from canonical
/// requirements.
void RuleBuilder::initWithGenericSignatureRequirements(
void RuleBuilder::initWithGenericSignature(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<Requirement> requirements) {
assert(!Initialized);
Initialized = 1;
Expand All @@ -47,6 +48,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
}

collectRulesFromReferencedProtocols();
collectPackShapeRules(genericParams);

// Add rewrite rules for all top-level requirements.
for (const auto &req : requirements)
Expand All @@ -56,6 +58,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
/// For building a rewrite system for a generic signature from user-written
/// requirements.
void RuleBuilder::initWithWrittenRequirements(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<StructuralRequirement> requirements) {
assert(!Initialized);
Initialized = 1;
Expand All @@ -68,6 +71,7 @@ void RuleBuilder::initWithWrittenRequirements(
}

collectRulesFromReferencedProtocols();
collectPackShapeRules(genericParams);

// Add rewrite rules for all top-level requirements.
for (const auto &req : requirements)
Expand Down Expand Up @@ -488,3 +492,77 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
localRules.end());
}
}

void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams) {
if (Dump) {
llvm::dbgs() << "adding shape rules\n";
}

if (!llvm::any_of(genericParams,
[](GenericTypeParamType *t) {
return t->isParameterPack();
})) {
return;
}

// Each non-pack generic parameter is part of the "scalar shape class", represented
// by the empty term.
for (auto *genericParam : genericParams) {
if (genericParam->isParameterPack())
continue;

// Add the rule (τ_d_i.[shape] => [shape]).
MutableTerm lhs;
lhs.add(Symbol::forGenericParam(
cast<GenericTypeParamType>(genericParam->getCanonicalType()), Context));
lhs.add(Symbol::forShape(Context));

MutableTerm rhs;
rhs.add(Symbol::forShape(Context));

PermanentRules.emplace_back(lhs, rhs);
}

// A member type T.[P:A] is part of the same shape class as its base type T.
llvm::DenseSet<Symbol> visited;

auto addMemberShapeRule = [&](const ProtocolDecl *proto, AssociatedTypeDecl *assocType) {
auto symbol = Symbol::forAssociatedType(proto, assocType->getName(), Context);
if (!visited.insert(symbol).second)
return;

// Add the rule ([P:A].[shape] => [shape]).
MutableTerm lhs;
lhs.add(symbol);
lhs.add(Symbol::forShape(Context));

MutableTerm rhs;
rhs.add(Symbol::forShape(Context));

// Consider it an imported rule, since it is not part of our minimization
// domain. It would be more logical if we added these in the protocol component
// machine for this protocol, but instead we add them in the "leaf" generic
// signature machine. This avoids polluting machines that do not involve
// parameter packs with these extra rules, which would otherwise just slow
// things down.
Rule rule(Term::get(lhs, Context), Term::get(rhs, Context));
rule.markPermanent();
ImportedRules.push_back(rule);
};

for (auto *proto : ProtocolsToImport) {
if (Dump) {
llvm::dbgs() << "adding member shape rules for protocol " << proto->getName() << "\n";
}

for (auto *assocType : proto->getAssociatedTypeMembers()) {
addMemberShapeRule(proto, assocType);
}

for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
addMemberShapeRule(proto, assocType);
}
}
}
}
7 changes: 5 additions & 2 deletions lib/AST/RequirementMachine/RuleBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ struct RuleBuilder {
Initialized = 0;
}

void initWithGenericSignatureRequirements(ArrayRef<Requirement> requirements);
void initWithWrittenRequirements(ArrayRef<StructuralRequirement> requirements);
void initWithGenericSignature(ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<Requirement> requirements);
void initWithWrittenRequirements(ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<StructuralRequirement> requirements);
void initWithProtocolSignatureRequirements(ArrayRef<const ProtocolDecl *> proto);
void initWithProtocolWrittenRequirements(
ArrayRef<const ProtocolDecl *> component,
Expand All @@ -106,6 +108,7 @@ struct RuleBuilder {
ArrayRef<Term> substitutions);
void addReferencedProtocol(const ProtocolDecl *proto);
void collectRulesFromReferencedProtocols();
void collectPackShapeRules(ArrayRef<GenericTypeParamType *> genericParams);

private:
void addPermanentProtocolRules(const ProtocolDecl *proto);
Expand Down
7 changes: 1 addition & 6 deletions lib/AST/RequirementMachine/Symbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,12 +700,7 @@ void Symbol::dump(llvm::raw_ostream &out) const {
}

case Kind::GenericParam: {
auto *gp = getGenericParam();
if (gp->isParameterPack()) {
out << "(" << Type(gp) << "…)";
} else {
out << Type(gp);
}
out << Type(getGenericParam());
return;
}

Expand Down
20 changes: 19 additions & 1 deletion test/Generics/pack-shape-requirements.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -disable-availability-checking 2>&1 | %FileCheck %s

protocol P {
associatedtype A
associatedtype A: P
}

// CHECK-LABEL: inferSameShape(ts:us:)
Expand Down Expand Up @@ -65,3 +65,21 @@ struct Ts<each T> {
func expandedParameters<each T, each Result>(_ t: repeat each T, transform: repeat (each T) -> each Result) -> (repeat each Result) {
fatalError()
}


//////
///
/// Same-type requirements should imply same-shape requirements.
///
//////

// CHECK-LABEL: sameType1
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : P, repeat each U : P, repeat (each T).[P]A == (each U).[P]A>
func sameType1<each T, each U>(_: repeat (each T, each U)) where repeat each T: P, repeat each U: P, repeat each T.A == each U.A {}

// Make sure inherited associated types are handled
protocol Q: P where A: Q {}

// CHECK-LABEL: sameType2
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : Q, repeat each U : Q, repeat (each T).[P]A.[P]A == (each U).[P]A.[P]A>
func sameType2<each T, each U>(_: repeat (each T, each U)) where repeat each T: Q, repeat each U: Q, repeat each T.A.A == each U.A.A {}
18 changes: 18 additions & 0 deletions validation-test/compiler_crashers_2_fixed/rdar108319167.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %target-swift-frontend -emit-ir %s

public protocol P {}

public protocol Q {
associatedtype A: P
}

public func f<T: P>(_: T) {}

public func foo1<each T: Q, each U>(t: repeat each T, u: repeat each U)
where repeat (each U) == (each T).A {
repeat f(each u)
}

public func foo2<each T: Q>(t: repeat each T, u: repeat each T.A) {
repeat f(each u)
}