Skip to content

Commit 1ef2174

Browse files
authored
Merge pull request swiftlang#35025 from hborla/generic-overload-search-space
[ConstraintSystem] Implement heuristics for pruning the generic operator overload search space
2 parents 1f1123f + a7b5476 commit 1ef2174

File tree

10 files changed

+347
-42
lines changed

10 files changed

+347
-42
lines changed

include/swift/AST/Identifier.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class Identifier {
110110
return isOperatorSlow();
111111
}
112112

113+
bool isArithmeticOperator() const {
114+
return is("+") || is("-") || is("*") || is("/") || is("%");
115+
}
116+
113117
// Returns whether this is a standard comparison operator,
114118
// such as '==', '>=' or '!=='.
115119
bool isStandardComparisonOperator() const {

include/swift/AST/TypeCheckRequests.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,6 +1900,27 @@ class CompareDeclSpecializationRequest
19001900
bool isCached() const { return true; }
19011901
};
19021902

1903+
/// Checks whether the first function decl is a refinement of the second,
1904+
/// meaning the two functions have the same structure, and the requirements
1905+
/// of the first are refining the requirements of the second.
1906+
class IsDeclRefinementOfRequest
1907+
: public SimpleRequest<IsDeclRefinementOfRequest,
1908+
bool(ValueDecl *, ValueDecl *),
1909+
RequestFlags::Cached> {
1910+
public:
1911+
using SimpleRequest::SimpleRequest;
1912+
1913+
private:
1914+
friend SimpleRequest;
1915+
1916+
// Evaluation.
1917+
bool evaluate(Evaluator &evaluator, ValueDecl *declA, ValueDecl *declB) const;
1918+
1919+
public:
1920+
// Caching.
1921+
bool isCached() const { return true; }
1922+
};
1923+
19031924
/// Checks whether this declaration inherits its superclass' designated and
19041925
/// convenience initializers.
19051926
class InheritsSuperclassInitializersRequest

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ SWIFT_REQUEST(TypeChecker, CodeCompletionFileRequest,
4141
SWIFT_REQUEST(TypeChecker, CompareDeclSpecializationRequest,
4242
bool (DeclContext *, ValueDecl *, ValueDecl *, bool), Cached,
4343
NoLocationInfo)
44+
SWIFT_REQUEST(TypeChecker, IsDeclRefinementOfRequest,
45+
bool (ValueDecl *, ValueDecl *),
46+
Cached, NoLocationInfo)
4447
SWIFT_REQUEST(TypeChecker, CustomAttrTypeRequest,
4548
Type(CustomAttr *, DeclContext *, CustomAttrTypeKind),
4649
SeparatelyCached, NoLocationInfo)

include/swift/Sema/ConstraintSystem.h

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,11 @@ class ConstraintSystem {
21642164
std::vector<std::pair<ConstraintLocator*, unsigned>>
21652165
DisjunctionChoices;
21662166

2167+
/// A map from applied disjunction constraints to the corresponding
2168+
/// argument function type.
2169+
llvm::SmallMapVector<ConstraintLocator *, const FunctionType *, 4>
2170+
AppliedDisjunctions;
2171+
21672172
/// For locators associated with call expressions, the trailing closure
21682173
/// matching rule that was applied.
21692174
std::vector<std::pair<ConstraintLocator*, TrailingClosureMatching>>
@@ -2653,6 +2658,9 @@ class ConstraintSystem {
26532658
/// The length of \c DisjunctionChoices.
26542659
unsigned numDisjunctionChoices;
26552660

2661+
/// The length of \c AppliedDisjunctions.
2662+
unsigned numAppliedDisjunctions;
2663+
26562664
/// The length of \c trailingClosureMatchingChoices;
26572665
unsigned numTrailingClosureMatchingChoices;
26582666

@@ -4954,6 +4962,21 @@ class ConstraintSystem {
49544962
SmallVectorImpl<unsigned> &Ordering,
49554963
SmallVectorImpl<unsigned> &PartitionBeginning);
49564964

4965+
/// Partition the choices in the range \c first to \c last into groups and
4966+
/// order the groups in the best order to attempt based on the argument
4967+
/// function type that the operator is applied to.
4968+
void partitionGenericOperators(ArrayRef<Constraint *> Choices,
4969+
SmallVectorImpl<unsigned>::iterator first,
4970+
SmallVectorImpl<unsigned>::iterator last,
4971+
ConstraintLocator *locator);
4972+
4973+
// If the given constraint is an applied disjunction, get the argument function
4974+
// that the disjunction is applied to.
4975+
const FunctionType *getAppliedDisjunctionArgumentFunction(Constraint *disjunction) {
4976+
assert(disjunction->getKind() == ConstraintKind::Disjunction);
4977+
return AppliedDisjunctions[disjunction->getLocator()];
4978+
}
4979+
49574980
/// The overload sets that have already been resolved along the current path.
49584981
const llvm::MapVector<ConstraintLocator *, SelectedOverload> &
49594982
getResolvedOverloads() const {
@@ -5454,8 +5477,12 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
54545477

54555478
bool IsExplicitConversion;
54565479

5480+
Constraint *Disjunction;
5481+
54575482
unsigned Index = 0;
54585483

5484+
bool needsGenericOperatorOrdering = true;
5485+
54595486
public:
54605487
using Element = DisjunctionChoice;
54615488

@@ -5464,22 +5491,17 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
54645491
? disjunction->getLocator()
54655492
: nullptr),
54665493
Choices(disjunction->getNestedConstraints()),
5467-
IsExplicitConversion(disjunction->isExplicitConversion()) {
5494+
IsExplicitConversion(disjunction->isExplicitConversion()),
5495+
Disjunction(disjunction) {
54685496
assert(disjunction->getKind() == ConstraintKind::Disjunction);
54695497
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
54705498

54715499
// Order and partition the disjunction choices.
54725500
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
54735501
}
54745502

5475-
DisjunctionChoiceProducer(ConstraintSystem &cs,
5476-
ArrayRef<Constraint *> choices,
5477-
ConstraintLocator *locator, bool explicitConversion)
5478-
: BindingProducer(cs, locator), Choices(choices),
5479-
IsExplicitConversion(explicitConversion) {
5480-
5481-
// Order and partition the disjunction choices.
5482-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
5503+
void setNeedsGenericOperatorOrdering(bool flag) {
5504+
needsGenericOperatorOrdering = flag;
54835505
}
54845506

54855507
Optional<Element> operator()() override {
@@ -5494,6 +5516,20 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
54945516

54955517
++Index;
54965518

5519+
auto choice = DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
5520+
IsExplicitConversion, isBeginningOfPartition);
5521+
// Partition the generic operators before producing the first generic
5522+
// operator disjunction choice.
5523+
if (needsGenericOperatorOrdering && choice.isGenericOperator()) {
5524+
unsigned nextPartitionIndex = (PartitionIndex < PartitionBeginning.size() ?
5525+
PartitionBeginning[PartitionIndex] : Ordering.size());
5526+
CS.partitionGenericOperators(Choices,
5527+
Ordering.begin() + currIndex,
5528+
Ordering.begin() + nextPartitionIndex,
5529+
Disjunction->getLocator());
5530+
needsGenericOperatorOrdering = false;
5531+
}
5532+
54975533
return DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
54985534
IsExplicitConversion, isBeginningOfPartition);
54995535
}

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ using namespace swift;
3838
using namespace swift::constraints;
3939

4040
static bool isArithmeticOperatorDecl(ValueDecl *vd) {
41-
return vd &&
42-
(vd->getBaseName() == "+" ||
43-
vd->getBaseName() == "-" ||
44-
vd->getBaseName() == "*" ||
45-
vd->getBaseName() == "/" ||
46-
vd->getBaseName() == "%");
41+
return vd && vd->getBaseIdentifier().isArithmeticOperator();
4742
}
4843

4944
static bool mergeRepresentativeEquivalenceClasses(ConstraintSystem &CS,

lib/Sema/CSSimplify.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8919,6 +8919,7 @@ bool ConstraintSystem::simplifyAppliedOverloads(
89198919
auto *applicableFn = result->first;
89208920
auto *fnTypeVar = applicableFn->getSecondType()->castTo<TypeVariableType>();
89218921
auto argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
8922+
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
89228923
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
89238924
/*numOptionalUnwraps*/ result->second,
89248925
locator);
@@ -8938,6 +8939,8 @@ bool ConstraintSystem::simplifyAppliedOverloads(
89388939
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
89398940
if (!disjunction)
89408941
return false;
8942+
8943+
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
89418944
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
89428945
numOptionalUnwraps, locator);
89438946
}

lib/Sema/CSSolver.cpp

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
465465
numFixes = cs.Fixes.size();
466466
numFixedRequirements = cs.FixedRequirements.size();
467467
numDisjunctionChoices = cs.DisjunctionChoices.size();
468+
numAppliedDisjunctions = cs.AppliedDisjunctions.size();
468469
numTrailingClosureMatchingChoices = cs.trailingClosureMatchingChoices.size();
469470
numOpenedTypes = cs.OpenedTypes.size();
470471
numOpenedExistentialTypes = cs.OpenedExistentialTypes.size();
@@ -519,6 +520,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
519520
// Remove any disjunction choices.
520521
truncate(cs.DisjunctionChoices, numDisjunctionChoices);
521522

523+
// Remove any applied disjunctions.
524+
truncate(cs.AppliedDisjunctions, numAppliedDisjunctions);
525+
522526
// Remove any trailing closure matching choices;
523527
truncate(
524528
cs.trailingClosureMatchingChoices, numTrailingClosureMatchingChoices);
@@ -2063,6 +2067,95 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
20632067
}
20642068
}
20652069

2070+
void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constraints,
2071+
SmallVectorImpl<unsigned>::iterator first,
2072+
SmallVectorImpl<unsigned>::iterator last,
2073+
ConstraintLocator *locator) {
2074+
auto *argFnType = AppliedDisjunctions[locator];
2075+
if (!isOperatorBindOverload(constraints[0]) || !argFnType)
2076+
return;
2077+
2078+
auto operatorName = constraints[0]->getOverloadChoice().getName();
2079+
if (!operatorName.getBaseIdentifier().isArithmeticOperator())
2080+
return;
2081+
2082+
SmallVector<unsigned, 4> concreteOverloads;
2083+
SmallVector<unsigned, 4> numericOverloads;
2084+
SmallVector<unsigned, 4> sequenceOverloads;
2085+
SmallVector<unsigned, 4> otherGenericOverloads;
2086+
2087+
auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
2088+
if (!nominal)
2089+
return false;
2090+
2091+
auto *protocol = TypeChecker::getProtocol(getASTContext(), SourceLoc(), kind);
2092+
2093+
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
2094+
return refined->inheritsFrom(protocol);
2095+
2096+
return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol,
2097+
nominal->getDeclContext());
2098+
};
2099+
2100+
// Gather Numeric and Sequence overloads into separate buckets.
2101+
for (auto iter = first; iter != last; ++iter) {
2102+
unsigned index = *iter;
2103+
auto *decl = constraints[index]->getOverloadChoice().getDecl();
2104+
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
2105+
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
2106+
concreteOverloads.push_back(index);
2107+
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) {
2108+
numericOverloads.push_back(index);
2109+
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::Sequence)) {
2110+
sequenceOverloads.push_back(index);
2111+
} else {
2112+
otherGenericOverloads.push_back(index);
2113+
}
2114+
}
2115+
2116+
auto sortPartition = [&](SmallVectorImpl<unsigned> &partition) {
2117+
llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool {
2118+
auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice().getDecl());
2119+
auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice().getDecl());
2120+
2121+
return TypeChecker::isDeclRefinementOf(declA, declB);
2122+
});
2123+
};
2124+
2125+
// Sort sequence overloads so that refinements are attempted first.
2126+
// If the solver finds a solution with an overload, it can then skip
2127+
// subsequent choices that the successful choice is a refinement of.
2128+
sortPartition(sequenceOverloads);
2129+
2130+
// Attempt concrete overloads first.
2131+
first = std::copy(concreteOverloads.begin(), concreteOverloads.end(), first);
2132+
2133+
// Check if any of the known argument types conform to one of the standard
2134+
// arithmetic protocols. If so, the sovler should attempt the corresponding
2135+
// overload choices first.
2136+
for (auto arg : argFnType->getParams()) {
2137+
auto argType = arg.getPlainType();
2138+
if (!argType || argType->hasTypeVariable())
2139+
continue;
2140+
2141+
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
2142+
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
2143+
numericOverloads.clear();
2144+
break;
2145+
}
2146+
2147+
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::Sequence)) {
2148+
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
2149+
sequenceOverloads.clear();
2150+
break;
2151+
}
2152+
}
2153+
2154+
first = std::copy(otherGenericOverloads.begin(), otherGenericOverloads.end(), first);
2155+
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
2156+
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
2157+
}
2158+
20662159
void ConstraintSystem::partitionDisjunction(
20672160
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
20682161
SmallVectorImpl<unsigned> &PartitionBeginning) {
@@ -2153,6 +2246,12 @@ void ConstraintSystem::partitionDisjunction(
21532246
});
21542247
}
21552248

2249+
// Gather the remaining options.
2250+
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
2251+
everythingElse.push_back(index);
2252+
return true;
2253+
});
2254+
21562255
// Local function to create the next partition based on the options
21572256
// passed in.
21582257
PartitionAppendCallback appendPartition =
@@ -2163,16 +2262,9 @@ void ConstraintSystem::partitionDisjunction(
21632262
}
21642263
};
21652264

2166-
// Gather the remaining options.
2167-
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
2168-
everythingElse.push_back(index);
2169-
return true;
2170-
});
21712265
appendPartition(favored);
21722266
appendPartition(everythingElse);
21732267
appendPartition(simdOperators);
2174-
2175-
// Now create the remaining partitions from what we previously collected.
21762268
appendPartition(unavailable);
21772269
appendPartition(disabled);
21782270

0 commit comments

Comments
 (0)