Skip to content

Commit 204d496

Browse files
authored
Merge pull request #63889 from xedin/cache-conformance-checks-in-cs
[ConstraintSystem] Add cache for conformance lookups
2 parents 8dbde04 + 32d2651 commit 204d496

File tree

8 files changed

+63
-54
lines changed

8 files changed

+63
-54
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,23 +195,22 @@ struct LiteralRequirement {
195195
/// \param canBeNil The flag that determines whether given type
196196
/// variable requires all of its bindings to be optional.
197197
///
198-
/// \param useDC The declaration context in which this literal
199-
/// requirement is used.
198+
/// \param CS The constraint system this literal requirement belongs to.
200199
///
201200
/// \returns a pair of bool and a type:
202201
/// - bool, true if binding covers given literal protocol;
203202
/// - type, non-null if binding type has to be adjusted
204203
/// to cover given literal protocol;
205204
std::pair<bool, Type> isCoveredBy(const PotentialBinding &binding,
206205
bool canBeNil,
207-
DeclContext *useDC) const;
206+
ConstraintSystem &CS) const;
208207

209208
/// Determines whether literal protocol associated with this
210209
/// meta-information is viable for inclusion as a defaultable binding.
211210
bool viableAsBinding() const { return !isCovered() && hasDefaultType(); }
212211

213212
private:
214-
bool isCoveredBy(Type type, DeclContext *useDC) const;
213+
bool isCoveredBy(Type type, ConstraintSystem &CS) const;
215214
};
216215

217216
struct PotentialBindings {

include/swift/Sema/ConstraintSystem.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,6 +2250,10 @@ class ConstraintSystem {
22502250
llvm::SmallMapVector<ConstraintLocator *, ArrayRef<OpenedType>, 4>
22512251
OpenedTypes;
22522252

2253+
/// A dictionary of all conformances that have been looked up by the solver.
2254+
llvm::DenseMap<std::pair<TypeBase *, ProtocolDecl *>, ProtocolConformanceRef>
2255+
Conformances;
2256+
22532257
/// The list of all generic requirements fixed along the current
22542258
/// solver path.
22552259
using FixedRequirement =
@@ -3946,7 +3950,7 @@ class ConstraintSystem {
39463950
///
39473951
/// \param wantRValue Whether this routine should look through
39483952
/// lvalues at each step.
3949-
Type getFixedTypeRecursive(Type type, bool wantRValue) const {
3953+
Type getFixedTypeRecursive(Type type, bool wantRValue) {
39503954
TypeMatchOptions flags = llvm::None;
39513955
return getFixedTypeRecursive(type, flags, wantRValue);
39523956
}
@@ -3964,7 +3968,7 @@ class ConstraintSystem {
39643968
/// \param wantRValue Whether this routine should look through
39653969
/// lvalues at each step.
39663970
Type getFixedTypeRecursive(Type type, TypeMatchOptions &flags,
3967-
bool wantRValue) const;
3971+
bool wantRValue);
39683972

39693973
/// Determine whether the given type variable occurs within the given type.
39703974
///
@@ -4198,6 +4202,10 @@ class ConstraintSystem {
41984202
ConstraintLocatorBuilder locator,
41994203
const OpenedTypeMap &replacements);
42004204

4205+
/// Check whether the given type conforms to the given protocol and if
4206+
/// so return a valid conformance reference.
4207+
ProtocolConformanceRef lookupConformance(Type type, ProtocolDecl *P);
4208+
42014209
/// Wrapper over swift::adjustFunctionTypeForConcurrency that passes along
42024210
/// the appropriate closure-type and opening extraction functions.
42034211
AnyFunctionType *adjustFunctionTypeForConcurrency(
@@ -4711,7 +4719,7 @@ class ConstraintSystem {
47114719
///
47124720
/// The resulting types can be compared canonically, so long as additional
47134721
/// type equivalence requirements aren't introduced between comparisons.
4714-
Type simplifyType(Type type) const;
4722+
Type simplifyType(Type type);
47154723

47164724
/// Simplify a type, by replacing type variables with either their
47174725
/// fixed types (if available) or their representatives.
@@ -4791,8 +4799,9 @@ class ConstraintSystem {
47914799

47924800
/// Simplifies a type by replacing type variables with the result of
47934801
/// \c getFixedTypeFn and performing lookup on dependent member types.
4794-
Type simplifyTypeImpl(Type type,
4795-
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) const;
4802+
Type
4803+
simplifyTypeImpl(Type type,
4804+
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn);
47964805

47974806
/// Attempt to simplify the given construction constraint.
47984807
///

lib/Sema/CSBindings.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ void BindingSet::determineLiteralCoverage() {
669669
Type adjustedTy;
670670

671671
std::tie(isCovered, adjustedTy) =
672-
literal.isCoveredBy(*binding, allowsNil, CS.DC);
672+
literal.isCoveredBy(*binding, allowsNil, CS);
673673

674674
if (!isCovered)
675675
continue;
@@ -872,7 +872,7 @@ void PotentialBindings::addDefault(Constraint *constraint) {
872872
Defaults.insert(constraint);
873873
}
874874

875-
bool LiteralRequirement::isCoveredBy(Type type, DeclContext *useDC) const {
875+
bool LiteralRequirement::isCoveredBy(Type type, ConstraintSystem &CS) const {
876876
auto coversDefaultType = [](Type type, Type defaultType) -> bool {
877877
if (!defaultType->hasUnboundGenericType())
878878
return type->isEqual(defaultType);
@@ -892,14 +892,12 @@ bool LiteralRequirement::isCoveredBy(Type type, DeclContext *useDC) const {
892892
if (hasDefaultType() && coversDefaultType(type, getDefaultType()))
893893
return true;
894894

895-
return (bool)TypeChecker::conformsToProtocol(type, getProtocol(),
896-
useDC->getParentModule());
895+
return bool(CS.lookupConformance(type, getProtocol()));
897896
}
898897

899898
std::pair<bool, Type>
900-
LiteralRequirement::isCoveredBy(const PotentialBinding &binding,
901-
bool canBeNil,
902-
DeclContext *useDC) const {
899+
LiteralRequirement::isCoveredBy(const PotentialBinding &binding, bool canBeNil,
900+
ConstraintSystem &CS) const {
903901
auto type = binding.BindingType;
904902
switch (binding.Kind) {
905903
case AllowedBindingKind::Exact:
@@ -919,7 +917,7 @@ LiteralRequirement::isCoveredBy(const PotentialBinding &binding,
919917
if (type->isTypeVariableOrMember() || type->isPlaceholder())
920918
return std::make_pair(false, Type());
921919

922-
if (isCoveredBy(type, useDC)) {
920+
if (isCoveredBy(type, CS)) {
923921
return std::make_pair(true, requiresUnwrap ? type : binding.BindingType);
924922
}
925923

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ namespace {
425425
// the literal.
426426
if (otherArgTy && otherArgTy->getAnyNominal()) {
427427
if (otherArgTy->isEqual(paramTy) &&
428-
TypeChecker::conformsToProtocol(
429-
otherArgTy, literalProto, CS.DC->getParentModule())) {
428+
CS.lookupConformance(otherArgTy, literalProto)) {
430429
return true;
431430
}
432431
} else if (Type defaultType =

lib/Sema/CSSimplify.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8348,8 +8348,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
83488348
}
83498349

83508350
// Check whether this type conforms to the protocol.
8351-
auto conformance = DC->getParentModule()->lookupConformance(
8352-
type, protocol, /*allowMissing=*/true);
8351+
auto conformance = lookupConformance(type, protocol);
83538352
if (conformance) {
83548353
return recordConformance(conformance);
83558354
}
@@ -8469,8 +8468,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
84698468

84708469
if (auto rawValue = isRawRepresentable(*this, type)) {
84718470
if (!rawValue->isTypeVariableOrMember() &&
8472-
TypeChecker::conformsToProtocol(rawValue, protocol,
8473-
DC->getParentModule())) {
8471+
lookupConformance(rawValue, protocol)) {
84748472
auto *fix = UseRawValue::create(*this, type, protocolTy, loc);
84758473
// Since this is a conformance requirement failure (where the
84768474
// source is most likely an argument), let's increase its impact
@@ -8629,11 +8627,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyTransitivelyConformsTo(
86298627

86308628
auto *protocol = protocolTy->castTo<ProtocolType>()->getDecl();
86318629

8632-
auto *M = DC->getParentModule();
8633-
86348630
// First, let's check whether the type itself conforms,
86358631
// if it does - we are done.
8636-
if (M->lookupConformance(resolvedTy, protocol))
8632+
if (lookupConformance(resolvedTy, protocol))
86378633
return SolutionKind::Solved;
86388634

86398635
// If the type doesn't conform, let's check whether
@@ -8705,10 +8701,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyTransitivelyConformsTo(
87058701
}
87068702
}
87078703

8708-
return llvm::any_of(typesToCheck,
8709-
[&](Type type) {
8710-
return bool(M->lookupConformance(type, protocol));
8711-
})
8704+
return llvm::any_of(
8705+
typesToCheck,
8706+
[&](Type type) { return bool(lookupConformance(type, protocol)); })
87128707
? SolutionKind::Solved
87138708
: SolutionKind::Error;
87148709
}
@@ -9295,7 +9290,7 @@ static bool mayBeForKeyPathSubscriptWithoutLabel(ConstraintSystem &cs,
92959290
/// This is useful to figure out whether it makes sense to
92969291
/// perform dynamic member lookup or not.
92979292
static bool
9298-
allFromConditionalConformances(DeclContext *DC, Type baseTy,
9293+
allFromConditionalConformances(ConstraintSystem &cs, Type baseTy,
92999294
ArrayRef<OverloadChoice> candidates) {
93009295
auto *NTD = baseTy->getAnyNominal();
93019296
if (!NTD)
@@ -9314,8 +9309,7 @@ allFromConditionalConformances(DeclContext *DC, Type baseTy,
93149309
}
93159310

93169311
if (auto *protocol = candidateDC->getSelfProtocolDecl()) {
9317-
auto conformance = DC->getParentModule()->lookupConformance(
9318-
baseTy, protocol);
9312+
auto conformance = cs.lookupConformance(baseTy, protocol);
93199313
if (!conformance.isConcrete())
93209314
return false;
93219315

@@ -10053,7 +10047,7 @@ performMemberLookup(ConstraintKind constraintKind, DeclNameRef memberName,
1005310047
const auto &candidates = result.ViableCandidates;
1005410048

1005510049
if ((candidates.empty() ||
10056-
allFromConditionalConformances(DC, instanceTy, candidates)) &&
10050+
allFromConditionalConformances(*this, instanceTy, candidates)) &&
1005710051
!isSelfRecursiveKeyPathDynamicMemberLookup(*this, baseTy,
1005810052
memberLocator)) {
1005910053
auto &ctx = getASTContext();
@@ -10636,7 +10630,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
1063610630
// called within extensions to that type (usually adding 'clamp').
1063710631
bool treatAsViable =
1063810632
(member.isSimpleName("min") || member.isSimpleName("max")) &&
10639-
allFromConditionalConformances(DC, baseTy, result.ViableCandidates);
10633+
allFromConditionalConformances(*this, baseTy,
10634+
result.ViableCandidates);
1064010635

1064110636
generateConstraints(
1064210637
candidates, memberTy, outerAlternatives, useDC, locator, llvm::None,
@@ -11044,8 +11039,7 @@ ConstraintSystem::simplifyValueWitnessConstraint(
1104411039
// conformance already?
1104511040
auto proto = requirement->getDeclContext()->getSelfProtocolDecl();
1104611041
assert(proto && "Value witness constraint for a non-requirement");
11047-
auto conformance = useDC->getParentModule()->lookupConformance(
11048-
baseObjectType, proto);
11042+
auto conformance = lookupConformance(baseObjectType, proto);
1104911043
if (!conformance)
1105011044
return fail();
1105111045

lib/Sema/CSSolver.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,8 +2175,7 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
21752175
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
21762176
return refined->inheritsFrom(protocol);
21772177

2178-
return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol,
2179-
CS.DC->getParentModule());
2178+
return bool(CS.lookupConformance(nominal->getDeclaredType(), protocol));
21802179
};
21812180

21822181
// Gather Numeric and Sequence overloads into separate buckets.

lib/Sema/ConstraintSystem.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,9 +1196,8 @@ llvm::Optional<Type> ConstraintSystem::isSetType(Type type) {
11961196
return llvm::None;
11971197
}
11981198

1199-
Type ConstraintSystem::getFixedTypeRecursive(Type type,
1200-
TypeMatchOptions &flags,
1201-
bool wantRValue) const {
1199+
Type ConstraintSystem::getFixedTypeRecursive(Type type, TypeMatchOptions &flags,
1200+
bool wantRValue) {
12021201

12031202
if (wantRValue)
12041203
type = type->getRValueType();
@@ -3826,7 +3825,7 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator,
38263825
namespace {
38273826

38283827
struct TypeSimplifier {
3829-
const ConstraintSystem &CS;
3828+
ConstraintSystem &CS;
38303829
llvm::function_ref<Type(TypeVariableType *)> GetFixedTypeFn;
38313830

38323831
struct ActivePackExpansion {
@@ -3835,7 +3834,7 @@ struct TypeSimplifier {
38353834
};
38363835
SmallVector<ActivePackExpansion, 4> ActivePackExpansions;
38373836

3838-
TypeSimplifier(const ConstraintSystem &CS,
3837+
TypeSimplifier(ConstraintSystem &CS,
38393838
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn)
38403839
: CS(CS), GetFixedTypeFn(getFixedTypeFn) {}
38413840

@@ -3983,8 +3982,7 @@ struct TypeSimplifier {
39833982
if (lookupBaseType->mayHaveMembers() ||
39843983
lookupBaseType->is<PackType>()) {
39853984
auto *proto = assocType->getProtocol();
3986-
auto conformance = CS.DC->getParentModule()->lookupConformance(
3987-
lookupBaseType, proto);
3985+
auto conformance = CS.lookupConformance(lookupBaseType, proto);
39883986
if (!conformance) {
39893987
// If the base type doesn't conform to the associatedtype's protocol,
39903988
// there will be a missing conformance fix applied in diagnostic mode,
@@ -4015,11 +4013,11 @@ struct TypeSimplifier {
40154013
} // end anonymous namespace
40164014

40174015
Type ConstraintSystem::simplifyTypeImpl(Type type,
4018-
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) const {
4016+
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) {
40194017
return type.transform(TypeSimplifier(*this, getFixedTypeFn));
40204018
}
40214019

4022-
Type ConstraintSystem::simplifyType(Type type) const {
4020+
Type ConstraintSystem::simplifyType(Type type) {
40234021
if (!type->hasTypeVariable())
40244022
return type;
40254023

@@ -6102,15 +6100,12 @@ bool constraints::hasAppliedSelf(const OverloadChoice &choice,
61026100
/// Check whether given type conforms to `RawRepresentable` protocol
61036101
/// and return the witness type.
61046102
Type constraints::isRawRepresentable(ConstraintSystem &cs, Type type) {
6105-
auto *DC = cs.DC;
6106-
61076103
auto rawReprType = TypeChecker::getProtocol(
61086104
cs.getASTContext(), SourceLoc(), KnownProtocolKind::RawRepresentable);
61096105
if (!rawReprType)
61106106
return Type();
61116107

6112-
auto conformance = TypeChecker::conformsToProtocol(type, rawReprType,
6113-
DC->getParentModule());
6108+
auto conformance = cs.lookupConformance(type, rawReprType);
61146109
if (conformance.isInvalid())
61156110
return Type();
61166111

@@ -7346,6 +7341,21 @@ bool ConstraintSystem::participatesInInference(ClosureExpr *closure) const {
73467341
return true;
73477342
}
73487343

7344+
ProtocolConformanceRef
7345+
ConstraintSystem::lookupConformance(Type type, ProtocolDecl *protocol) {
7346+
auto cacheKey = std::make_pair(type.getPointer(), protocol);
7347+
7348+
auto cachedConformance = Conformances.find(cacheKey);
7349+
if (cachedConformance != Conformances.end())
7350+
return cachedConformance->second;
7351+
7352+
auto conformance =
7353+
DC->getParentModule()->lookupConformance(type, protocol,
7354+
/*allowMissing=*/true);
7355+
Conformances[cacheKey] = conformance;
7356+
return conformance;
7357+
}
7358+
73497359
TypeVarBindingProducer::TypeVarBindingProducer(BindingSet &bindings)
73507360
: BindingProducer(bindings.getConstraintSystem(),
73517361
bindings.getTypeVariable()->getImpl().getLocator()),
@@ -7478,7 +7488,7 @@ bool TypeVarBindingProducer::requiresOptionalAdjustment(
74787488
auto *proto = CS.getASTContext().getProtocol(
74797489
KnownProtocolKind::ExpressibleByNilLiteral);
74807490

7481-
return !proto->getParentModule()->lookupConformance(type, proto);
7491+
return !CS.lookupConformance(type, proto);
74827492
} else if (binding.isDefaultableBinding() && binding.BindingType->isAny()) {
74837493
return true;
74847494
}

test/Constraints/array_literal.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ protocol P { }
358358
struct PArray<T> { }
359359

360360
extension PArray : ExpressibleByArrayLiteral where T: P {
361+
// expected-note@-1 {{requirement from conditional conformance of 'PArray<String>' to 'ExpressibleByArrayLiteral'}}
361362
typealias ArrayLiteralElement = T
362363

363364
init(arrayLiteral elements: T...) { }
@@ -367,7 +368,7 @@ extension Int: P { }
367368

368369
func testConditional(i: Int, s: String) {
369370
let _: PArray<Int> = [i, i, i]
370-
let _: PArray<String> = [s, s, s] // expected-error{{cannot convert value of type '[String]' to specified type 'PArray<String>'}}
371+
let _: PArray<String> = [s, s, s] // expected-error{{generic struct 'PArray' requires that 'String' conform to 'P'}}
371372
}
372373

373374

0 commit comments

Comments
 (0)