Skip to content

[ConstraintSystem] Add cache for conformance lookups #63889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions include/swift/Sema/CSBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,23 +195,22 @@ struct LiteralRequirement {
/// \param canBeNil The flag that determines whether given type
/// variable requires all of its bindings to be optional.
///
/// \param useDC The declaration context in which this literal
/// requirement is used.
/// \param CS The constraint system this literal requirement belongs to.
///
/// \returns a pair of bool and a type:
/// - bool, true if binding covers given literal protocol;
/// - type, non-null if binding type has to be adjusted
/// to cover given literal protocol;
std::pair<bool, Type> isCoveredBy(const PotentialBinding &binding,
bool canBeNil,
DeclContext *useDC) const;
ConstraintSystem &CS) const;

/// Determines whether literal protocol associated with this
/// meta-information is viable for inclusion as a defaultable binding.
bool viableAsBinding() const { return !isCovered() && hasDefaultType(); }

private:
bool isCoveredBy(Type type, DeclContext *useDC) const;
bool isCoveredBy(Type type, ConstraintSystem &CS) const;
};

struct PotentialBindings {
Expand Down
19 changes: 14 additions & 5 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2250,6 +2250,10 @@ class ConstraintSystem {
llvm::SmallMapVector<ConstraintLocator *, ArrayRef<OpenedType>, 4>
OpenedTypes;

/// A dictionary of all conformances that have been looked up by the solver.
llvm::DenseMap<std::pair<TypeBase *, ProtocolDecl *>, ProtocolConformanceRef>
Conformances;

/// The list of all generic requirements fixed along the current
/// solver path.
using FixedRequirement =
Expand Down Expand Up @@ -3946,7 +3950,7 @@ class ConstraintSystem {
///
/// \param wantRValue Whether this routine should look through
/// lvalues at each step.
Type getFixedTypeRecursive(Type type, bool wantRValue) const {
Type getFixedTypeRecursive(Type type, bool wantRValue) {
TypeMatchOptions flags = llvm::None;
return getFixedTypeRecursive(type, flags, wantRValue);
}
Expand All @@ -3964,7 +3968,7 @@ class ConstraintSystem {
/// \param wantRValue Whether this routine should look through
/// lvalues at each step.
Type getFixedTypeRecursive(Type type, TypeMatchOptions &flags,
bool wantRValue) const;
bool wantRValue);

/// Determine whether the given type variable occurs within the given type.
///
Expand Down Expand Up @@ -4198,6 +4202,10 @@ class ConstraintSystem {
ConstraintLocatorBuilder locator,
const OpenedTypeMap &replacements);

/// Check whether the given type conforms to the given protocol and if
/// so return a valid conformance reference.
ProtocolConformanceRef lookupConformance(Type type, ProtocolDecl *P);

/// Wrapper over swift::adjustFunctionTypeForConcurrency that passes along
/// the appropriate closure-type and opening extraction functions.
AnyFunctionType *adjustFunctionTypeForConcurrency(
Expand Down Expand Up @@ -4711,7 +4719,7 @@ class ConstraintSystem {
///
/// The resulting types can be compared canonically, so long as additional
/// type equivalence requirements aren't introduced between comparisons.
Type simplifyType(Type type) const;
Type simplifyType(Type type);

/// Simplify a type, by replacing type variables with either their
/// fixed types (if available) or their representatives.
Expand Down Expand Up @@ -4791,8 +4799,9 @@ class ConstraintSystem {

/// Simplifies a type by replacing type variables with the result of
/// \c getFixedTypeFn and performing lookup on dependent member types.
Type simplifyTypeImpl(Type type,
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) const;
Type
simplifyTypeImpl(Type type,
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn);

/// Attempt to simplify the given construction constraint.
///
Expand Down
14 changes: 6 additions & 8 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ void BindingSet::determineLiteralCoverage() {
Type adjustedTy;

std::tie(isCovered, adjustedTy) =
literal.isCoveredBy(*binding, allowsNil, CS.DC);
literal.isCoveredBy(*binding, allowsNil, CS);

if (!isCovered)
continue;
Expand Down Expand Up @@ -872,7 +872,7 @@ void PotentialBindings::addDefault(Constraint *constraint) {
Defaults.insert(constraint);
}

bool LiteralRequirement::isCoveredBy(Type type, DeclContext *useDC) const {
bool LiteralRequirement::isCoveredBy(Type type, ConstraintSystem &CS) const {
auto coversDefaultType = [](Type type, Type defaultType) -> bool {
if (!defaultType->hasUnboundGenericType())
return type->isEqual(defaultType);
Expand All @@ -892,14 +892,12 @@ bool LiteralRequirement::isCoveredBy(Type type, DeclContext *useDC) const {
if (hasDefaultType() && coversDefaultType(type, getDefaultType()))
return true;

return (bool)TypeChecker::conformsToProtocol(type, getProtocol(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

conformsToProtocol() also checks conditional requirements; is that an issue here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually better if it doesn't check them here, leads to better diagnostics.

useDC->getParentModule());
return bool(CS.lookupConformance(type, getProtocol()));
}

std::pair<bool, Type>
LiteralRequirement::isCoveredBy(const PotentialBinding &binding,
bool canBeNil,
DeclContext *useDC) const {
LiteralRequirement::isCoveredBy(const PotentialBinding &binding, bool canBeNil,
ConstraintSystem &CS) const {
auto type = binding.BindingType;
switch (binding.Kind) {
case AllowedBindingKind::Exact:
Expand All @@ -919,7 +917,7 @@ LiteralRequirement::isCoveredBy(const PotentialBinding &binding,
if (type->isTypeVariableOrMember() || type->isPlaceholder())
return std::make_pair(false, Type());

if (isCoveredBy(type, useDC)) {
if (isCoveredBy(type, CS)) {
return std::make_pair(true, requiresUnwrap ? type : binding.BindingType);
}

Expand Down
3 changes: 1 addition & 2 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,7 @@ namespace {
// the literal.
if (otherArgTy && otherArgTy->getAnyNominal()) {
if (otherArgTy->isEqual(paramTy) &&
TypeChecker::conformsToProtocol(
otherArgTy, literalProto, CS.DC->getParentModule())) {
CS.lookupConformance(otherArgTy, literalProto)) {
return true;
}
} else if (Type defaultType =
Expand Down
30 changes: 12 additions & 18 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8425,8 +8425,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
}

// Check whether this type conforms to the protocol.
auto conformance = DC->getParentModule()->lookupConformance(
type, protocol, /*allowMissing=*/true);
auto conformance = lookupConformance(type, protocol);
if (conformance) {
return recordConformance(conformance);
}
Expand Down Expand Up @@ -8546,8 +8545,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(

if (auto rawValue = isRawRepresentable(*this, type)) {
if (!rawValue->isTypeVariableOrMember() &&
TypeChecker::conformsToProtocol(rawValue, protocol,
DC->getParentModule())) {
lookupConformance(rawValue, protocol)) {
auto *fix = UseRawValue::create(*this, type, protocolTy, loc);
// Since this is a conformance requirement failure (where the
// source is most likely an argument), let's increase its impact
Expand Down Expand Up @@ -8706,11 +8704,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyTransitivelyConformsTo(

auto *protocol = protocolTy->castTo<ProtocolType>()->getDecl();

auto *M = DC->getParentModule();

// First, let's check whether the type itself conforms,
// if it does - we are done.
if (M->lookupConformance(resolvedTy, protocol))
if (lookupConformance(resolvedTy, protocol))
return SolutionKind::Solved;

// If the type doesn't conform, let's check whether
Expand Down Expand Up @@ -8782,10 +8778,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyTransitivelyConformsTo(
}
}

return llvm::any_of(typesToCheck,
[&](Type type) {
return bool(M->lookupConformance(type, protocol));
})
return llvm::any_of(
typesToCheck,
[&](Type type) { return bool(lookupConformance(type, protocol)); })
? SolutionKind::Solved
: SolutionKind::Error;
}
Expand Down Expand Up @@ -9372,7 +9367,7 @@ static bool mayBeForKeyPathSubscriptWithoutLabel(ConstraintSystem &cs,
/// This is useful to figure out whether it makes sense to
/// perform dynamic member lookup or not.
static bool
allFromConditionalConformances(DeclContext *DC, Type baseTy,
allFromConditionalConformances(ConstraintSystem &cs, Type baseTy,
ArrayRef<OverloadChoice> candidates) {
auto *NTD = baseTy->getAnyNominal();
if (!NTD)
Expand All @@ -9391,8 +9386,7 @@ allFromConditionalConformances(DeclContext *DC, Type baseTy,
}

if (auto *protocol = candidateDC->getSelfProtocolDecl()) {
auto conformance = DC->getParentModule()->lookupConformance(
baseTy, protocol);
auto conformance = cs.lookupConformance(baseTy, protocol);
if (!conformance.isConcrete())
return false;

Expand Down Expand Up @@ -10130,7 +10124,7 @@ performMemberLookup(ConstraintKind constraintKind, DeclNameRef memberName,
const auto &candidates = result.ViableCandidates;

if ((candidates.empty() ||
allFromConditionalConformances(DC, instanceTy, candidates)) &&
allFromConditionalConformances(*this, instanceTy, candidates)) &&
!isSelfRecursiveKeyPathDynamicMemberLookup(*this, baseTy,
memberLocator)) {
auto &ctx = getASTContext();
Expand Down Expand Up @@ -10713,7 +10707,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
// called within extensions to that type (usually adding 'clamp').
bool treatAsViable =
(member.isSimpleName("min") || member.isSimpleName("max")) &&
allFromConditionalConformances(DC, baseTy, result.ViableCandidates);
allFromConditionalConformances(*this, baseTy,
result.ViableCandidates);

generateConstraints(
candidates, memberTy, outerAlternatives, useDC, locator, llvm::None,
Expand Down Expand Up @@ -11121,8 +11116,7 @@ ConstraintSystem::simplifyValueWitnessConstraint(
// conformance already?
auto proto = requirement->getDeclContext()->getSelfProtocolDecl();
assert(proto && "Value witness constraint for a non-requirement");
auto conformance = useDC->getParentModule()->lookupConformance(
baseObjectType, proto);
auto conformance = lookupConformance(baseObjectType, proto);
if (!conformance)
return fail();

Expand Down
3 changes: 1 addition & 2 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,8 +2175,7 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
return refined->inheritsFrom(protocol);

return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol,
CS.DC->getParentModule());
return bool(CS.lookupConformance(nominal->getDeclaredType(), protocol));
};

// Gather Numeric and Sequence overloads into separate buckets.
Expand Down
38 changes: 24 additions & 14 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,9 +1196,8 @@ llvm::Optional<Type> ConstraintSystem::isSetType(Type type) {
return llvm::None;
}

Type ConstraintSystem::getFixedTypeRecursive(Type type,
TypeMatchOptions &flags,
bool wantRValue) const {
Type ConstraintSystem::getFixedTypeRecursive(Type type, TypeMatchOptions &flags,
bool wantRValue) {

if (wantRValue)
type = type->getRValueType();
Expand Down Expand Up @@ -3822,7 +3821,7 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator,
namespace {

struct TypeSimplifier {
const ConstraintSystem &CS;
ConstraintSystem &CS;
llvm::function_ref<Type(TypeVariableType *)> GetFixedTypeFn;

struct ActivePackExpansion {
Expand All @@ -3831,7 +3830,7 @@ struct TypeSimplifier {
};
SmallVector<ActivePackExpansion, 4> ActivePackExpansions;

TypeSimplifier(const ConstraintSystem &CS,
TypeSimplifier(ConstraintSystem &CS,
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn)
: CS(CS), GetFixedTypeFn(getFixedTypeFn) {}

Expand Down Expand Up @@ -3979,8 +3978,7 @@ struct TypeSimplifier {
if (lookupBaseType->mayHaveMembers() ||
lookupBaseType->is<PackType>()) {
auto *proto = assocType->getProtocol();
auto conformance = CS.DC->getParentModule()->lookupConformance(
lookupBaseType, proto);
auto conformance = CS.lookupConformance(lookupBaseType, proto);
if (!conformance) {
// If the base type doesn't conform to the associatedtype's protocol,
// there will be a missing conformance fix applied in diagnostic mode,
Expand Down Expand Up @@ -4011,11 +4009,11 @@ struct TypeSimplifier {
} // end anonymous namespace

Type ConstraintSystem::simplifyTypeImpl(Type type,
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) const {
llvm::function_ref<Type(TypeVariableType *)> getFixedTypeFn) {
return type.transform(TypeSimplifier(*this, getFixedTypeFn));
}

Type ConstraintSystem::simplifyType(Type type) const {
Type ConstraintSystem::simplifyType(Type type) {
if (!type->hasTypeVariable())
return type;

Expand Down Expand Up @@ -6098,15 +6096,12 @@ bool constraints::hasAppliedSelf(const OverloadChoice &choice,
/// Check whether given type conforms to `RawRepresentable` protocol
/// and return the witness type.
Type constraints::isRawRepresentable(ConstraintSystem &cs, Type type) {
auto *DC = cs.DC;

auto rawReprType = TypeChecker::getProtocol(
cs.getASTContext(), SourceLoc(), KnownProtocolKind::RawRepresentable);
if (!rawReprType)
return Type();

auto conformance = TypeChecker::conformsToProtocol(type, rawReprType,
DC->getParentModule());
auto conformance = cs.lookupConformance(type, rawReprType);
if (conformance.isInvalid())
return Type();

Expand Down Expand Up @@ -7342,6 +7337,21 @@ bool ConstraintSystem::participatesInInference(ClosureExpr *closure) const {
return true;
}

ProtocolConformanceRef
ConstraintSystem::lookupConformance(Type type, ProtocolDecl *protocol) {
auto cacheKey = std::make_pair(type.getPointer(), protocol);

auto cachedConformance = Conformances.find(cacheKey);
if (cachedConformance != Conformances.end())
return cachedConformance->second;

auto conformance =
DC->getParentModule()->lookupConformance(type, protocol,
/*allowMissing=*/true);
Conformances[cacheKey] = conformance;
return conformance;
}

TypeVarBindingProducer::TypeVarBindingProducer(BindingSet &bindings)
: BindingProducer(bindings.getConstraintSystem(),
bindings.getTypeVariable()->getImpl().getLocator()),
Expand Down Expand Up @@ -7474,7 +7484,7 @@ bool TypeVarBindingProducer::requiresOptionalAdjustment(
auto *proto = CS.getASTContext().getProtocol(
KnownProtocolKind::ExpressibleByNilLiteral);

return !proto->getParentModule()->lookupConformance(type, proto);
return !CS.lookupConformance(type, proto);
} else if (binding.isDefaultableBinding() && binding.BindingType->isAny()) {
return true;
}
Expand Down
3 changes: 2 additions & 1 deletion test/Constraints/array_literal.swift
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ protocol P { }
struct PArray<T> { }

extension PArray : ExpressibleByArrayLiteral where T: P {
// expected-note@-1 {{requirement from conditional conformance of 'PArray<String>' to 'ExpressibleByArrayLiteral'}}
typealias ArrayLiteralElement = T

init(arrayLiteral elements: T...) { }
Expand All @@ -367,7 +368,7 @@ extension Int: P { }

func testConditional(i: Int, s: String) {
let _: PArray<Int> = [i, i, i]
let _: PArray<String> = [s, s, s] // expected-error{{cannot convert value of type '[String]' to specified type 'PArray<String>'}}
let _: PArray<String> = [s, s, s] // expected-error{{generic struct 'PArray' requires that 'String' conform to 'P'}}
}


Expand Down