Skip to content

Commit 6c0ccfc

Browse files
authored
Merge pull request #41838 from slavapestov/rqm-conditional-requirement-inference-vs-rule-sharing
RequirementMachine: Fix bad interaction between rule sharing and conditional requirement inference
2 parents b17b1a9 + 6e6c8c2 commit 6c0ccfc

16 files changed

+907
-647
lines changed

lib/AST/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ add_swift_host_library(swiftAST STATIC
8383
RequirementMachine/InterfaceType.cpp
8484
RequirementMachine/KnuthBendix.cpp
8585
RequirementMachine/MinimalConformances.cpp
86+
RequirementMachine/NameLookup.cpp
8687
RequirementMachine/NormalizeRewritePath.cpp
8788
RequirementMachine/PropertyMap.cpp
8889
RequirementMachine/PropertyRelations.cpp
@@ -95,6 +96,7 @@ add_swift_host_library(swiftAST STATIC
9596
RequirementMachine/RewriteLoop.cpp
9697
RequirementMachine/RewriteSystem.cpp
9798
RequirementMachine/Rule.cpp
99+
RequirementMachine/RuleBuilder.cpp
98100
RequirementMachine/SimplifySubstitutions.cpp
99101
RequirementMachine/Symbol.cpp
100102
RequirementMachine/Term.cpp

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
#include "swift/AST/Types.h"
147147
#include "llvm/ADT/DenseMap.h"
148148
#include "llvm/ADT/SmallVector.h"
149+
#include "NameLookup.h"
149150
#include "RequirementLowering.h"
150151

151152
using namespace swift;
@@ -182,30 +183,6 @@ class ConcreteContraction {
182183

183184
} // end namespace
184185

185-
/// Find the most canonical member type of \p decl named \p name, using the
186-
/// canonical type order.
187-
static TypeDecl *lookupConcreteNestedType(ModuleDecl *module,
188-
NominalTypeDecl *decl,
189-
Identifier name) {
190-
SmallVector<ValueDecl *, 2> foundMembers;
191-
module->lookupQualified(
192-
decl, DeclNameRef(name),
193-
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
194-
foundMembers);
195-
196-
SmallVector<TypeDecl *, 2> concreteDecls;
197-
for (auto member : foundMembers)
198-
concreteDecls.push_back(cast<TypeDecl>(member));
199-
200-
if (concreteDecls.empty())
201-
return nullptr;
202-
203-
return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
204-
[](TypeDecl *type1, TypeDecl *type2) {
205-
return TypeDecl::compare(type1, type2) < 0;
206-
});
207-
}
208-
209186
/// A re-implementation of Type::subst() that also handles unresolved
210187
/// DependentMemberTypes by performing name lookup into the base type.
211188
///
@@ -267,12 +244,12 @@ Optional<Type> ConcreteContraction::substTypeParameter(
267244
return None;
268245
}
269246

270-
auto *module = decl->getParentModule();
271-
272247
// An unresolved DependentMemberType stores an identifier. Handle this
273248
// by performing a name lookup into the base type.
274-
auto *typeDecl = lookupConcreteNestedType(module, decl,
275-
memberType->getName());
249+
SmallVector<TypeDecl *> concreteDecls;
250+
lookupConcreteNestedType(decl, memberType->getName(), concreteDecls);
251+
252+
auto *typeDecl = findBestConcreteNestedType(concreteDecls);
276253
if (typeDecl == nullptr) {
277254
// The base type doesn't contain a member type with this name, in which
278255
// case the requirement remains unsubstituted.
@@ -285,7 +262,7 @@ Optional<Type> ConcreteContraction::substTypeParameter(
285262

286263
// Substitute the base type into the member type.
287264
auto subMap = (*substBaseType)->getContextSubstitutionMap(
288-
module, typeDecl->getDeclContext());
265+
decl->getParentModule(), typeDecl->getDeclContext());
289266
return typeDecl->getDeclaredInterfaceType().subst(subMap);
290267
}
291268

@@ -531,6 +508,12 @@ bool ConcreteContraction::performConcreteContraction(
531508
// Phase 2: Replace each concretely-conforming generic parameter with its
532509
// concrete type.
533510
for (auto req : requirements) {
511+
if (Debug) {
512+
llvm::dbgs() << "@ Original requirement: ";
513+
req.req.dump(llvm::dbgs());
514+
llvm::dbgs() << "\n";
515+
}
516+
534517
// Substitute the requirement.
535518
Optional<Requirement> substReq = substRequirement(req.req);
536519

lib/AST/RequirementMachine/ConcreteTypeWitness.cpp

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <vector>
2525
#include "PropertyMap.h"
2626
#include "RequirementLowering.h"
27+
#include "RuleBuilder.h"
2728

2829
using namespace swift;
2930
using namespace rewriting;
@@ -200,7 +201,9 @@ void PropertyMap::concretizeTypeWitnessInConformance(
200201
AssociatedTypeDecl *assocType) const {
201202
auto concreteType = concreteConformanceSymbol.getConcreteType();
202203
auto substitutions = concreteConformanceSymbol.getSubstitutions();
203-
auto *proto = concreteConformanceSymbol.getProtocol();
204+
205+
auto *proto = assocType->getProtocol();
206+
assert(proto == concreteConformanceSymbol.getProtocol());
204207

205208
if (Debug.contains(DebugFlags::ConcretizeNestedTypes)) {
206209
llvm::dbgs() << "^^ " << "Looking up type witness for "
@@ -529,8 +532,8 @@ void PropertyMap::inferConditionalRequirements(
529532
return;
530533

531534
SmallVector<Requirement, 2> desugaredRequirements;
532-
// FIXME: Store errors in the rewrite system to be diagnosed
533-
// from the top-level generic signature requests.
535+
536+
// FIXME: Do we need to diagnose these errors?
534537
SmallVector<RequirementError, 2> errors;
535538

536539
// First, desugar all conditional requirements.
@@ -545,47 +548,17 @@ void PropertyMap::inferConditionalRequirements(
545548
}
546549

547550
// Now, convert desugared conditional requirements to rules.
548-
for (auto req : desugaredRequirements) {
549-
if (Debug.contains(DebugFlags::ConditionalRequirements)) {
550-
llvm::dbgs() << "@@@ Desugared requirement: ";
551-
req.dump(llvm::dbgs());
552-
llvm::dbgs() << "\n";
553-
}
554-
555-
if (req.getKind() == RequirementKind::Conformance) {
556-
auto *proto = req.getProtocolDecl();
557-
558-
// If we haven't seen this protocol before, add rules for its
559-
// requirements.
560-
if (!System.isKnownProtocol(proto)) {
561-
if (Debug.contains(DebugFlags::ConditionalRequirements)) {
562-
llvm::dbgs() << "@@@ Unknown protocol: "<< proto->getName() << "\n";
563-
}
564551

565-
RuleBuilder builder(Context, System.getReferencedProtocols());
566-
builder.addReferencedProtocol(proto);
567-
builder.collectRulesFromReferencedProtocols();
552+
// This will update System.getReferencedProtocols() with any new
553+
// protocols that were imported.
554+
RuleBuilder builder(Context, System.getReferencedProtocols());
555+
builder.initWithConditionalRequirements(desugaredRequirements,
556+
substitutions);
568557

569-
for (const auto &rule : builder.PermanentRules)
570-
System.addPermanentRule(rule.first, rule.second);
558+
assert(builder.PermanentRules.empty());
559+
assert(builder.WrittenRequirements.empty());
571560

572-
for (const auto &rule : builder.RequirementRules) {
573-
auto lhs = std::get<0>(rule);
574-
auto rhs = std::get<1>(rule);
575-
System.addExplicitRule(lhs, rhs, /*requirementID=*/None);
576-
}
577-
}
578-
}
579-
580-
auto pair = getRuleForRequirement(req.getCanonical(), /*proto=*/nullptr,
581-
substitutions, Context);
582-
583-
if (Debug.contains(DebugFlags::ConditionalRequirements)) {
584-
llvm::dbgs() << "@@@ Induced rule from conditional requirement: "
585-
<< pair.first << " => " << pair.second << "\n";
586-
}
587-
588-
// FIXME: Do we need a rewrite path here?
589-
(void) System.addRule(pair.first, pair.second);
590-
}
561+
System.addRules(std::move(builder.ImportedRules),
562+
std::move(builder.PermanentRules),
563+
std::move(builder.RequirementRules));
591564
}

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
// Use those methods instead of calling into the RequirementMachine directly.
1515
//
1616
//===----------------------------------------------------------------------===//
17+
1718
#include "swift/AST/ASTContext.h"
1819
#include "swift/AST/Decl.h"
1920
#include "swift/AST/GenericSignature.h"
2021
#include "swift/AST/Module.h"
21-
#include "llvm/ADT/TinyPtrVector.h"
2222
#include <vector>
23-
23+
#include "NameLookup.h"
2424
#include "RequirementMachine.h"
2525

2626
using namespace swift;
@@ -609,26 +609,6 @@ RequirementMachine::getConformanceAccessPath(Type type,
609609
}
610610
}
611611

612-
static void lookupConcreteNestedType(NominalTypeDecl *decl,
613-
Identifier name,
614-
SmallVectorImpl<TypeDecl *> &concreteDecls) {
615-
SmallVector<ValueDecl *, 2> foundMembers;
616-
decl->getParentModule()->lookupQualified(
617-
decl, DeclNameRef(name),
618-
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
619-
foundMembers);
620-
for (auto member : foundMembers)
621-
concreteDecls.push_back(cast<TypeDecl>(member));
622-
}
623-
624-
static TypeDecl *
625-
findBestConcreteNestedType(SmallVectorImpl<TypeDecl *> &concreteDecls) {
626-
return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
627-
[](TypeDecl *type1, TypeDecl *type2) {
628-
return TypeDecl::compare(type1, type2) < 0;
629-
});
630-
}
631-
632612
TypeDecl *
633613
RequirementMachine::lookupNestedType(Type depType, Identifier name) const {
634614
auto term = Context.getMutableTermForType(depType->getCanonicalType(),

lib/AST/RequirementMachine/KnuthBendix.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount,
391391

392392
for (const auto &pair : resolvedCriticalPairs) {
393393
// Check if we've already done too much work.
394-
if (Rules.size() > maxRuleCount)
394+
if (getLocalRules().size() > maxRuleCount)
395395
return std::make_pair(CompletionResult::MaxRuleCount, Rules.size() - 1);
396396

397397
if (!addRule(pair.LHS, pair.RHS, &pair.Path))
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===--- NameLookup.cpp - Name lookup utilities ---------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "NameLookup.h"
14+
#include "swift/AST/Decl.h"
15+
#include "swift/AST/Module.h"
16+
#include "llvm/ADT/SmallVector.h"
17+
#include <algorithm>
18+
19+
using namespace swift;
20+
using namespace rewriting;
21+
22+
void
23+
swift::rewriting::lookupConcreteNestedType(
24+
NominalTypeDecl *decl,
25+
Identifier name,
26+
SmallVectorImpl<TypeDecl *> &concreteDecls) {
27+
SmallVector<ValueDecl *, 2> foundMembers;
28+
decl->getParentModule()->lookupQualified(
29+
decl, DeclNameRef(name),
30+
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
31+
foundMembers);
32+
for (auto member : foundMembers)
33+
concreteDecls.push_back(cast<TypeDecl>(member));
34+
}
35+
36+
TypeDecl *
37+
swift::rewriting::findBestConcreteNestedType(
38+
SmallVectorImpl<TypeDecl *> &concreteDecls) {
39+
if (concreteDecls.empty())
40+
return nullptr;
41+
42+
return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
43+
[](TypeDecl *type1, TypeDecl *type2) {
44+
return TypeDecl::compare(type1, type2) < 0;
45+
});
46+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===--- NameLookup.h - Name lookup utilities -------------------*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_RQM_NAMELOOKUP_H
14+
#define SWIFT_RQM_NAMELOOKUP_H
15+
16+
#include "llvm/ADT/SmallVector.h"
17+
18+
namespace swift {
19+
20+
class Identifier;
21+
class NominalTypeDecl;
22+
class TypeDecl;
23+
24+
namespace rewriting {
25+
26+
void lookupConcreteNestedType(
27+
NominalTypeDecl *decl,
28+
Identifier name,
29+
llvm::SmallVectorImpl<TypeDecl *> &concreteDecls);
30+
31+
TypeDecl *findBestConcreteNestedType(
32+
llvm::SmallVectorImpl<TypeDecl *> &concreteDecls);
33+
34+
} // end namespace rewriting
35+
36+
} // end namespace swift
37+
38+
#endif

0 commit comments

Comments
 (0)