Skip to content

Commit 19350d9

Browse files
authored
Merge pull request #63585 from xedin/solver-perf
[CSSolver] Implementation of disjunction choice favoring algorithm
2 parents 5dbe202 + bc3a15f commit 19350d9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2066
-1715
lines changed

include/swift/Sema/Constraint.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -825,11 +825,6 @@ class Constraint final : public llvm::ilist_node<Constraint>,
825825
});
826826
}
827827

828-
/// Returns the number of resolved argument types for an applied disjunction
829-
/// constraint. This is always zero for disjunctions that do not represent
830-
/// an applied overload.
831-
unsigned countResolvedArgumentTypes(ConstraintSystem &cs) const;
832-
833828
/// Determine if this constraint represents explicit conversion,
834829
/// e.g. coercion constraint "as X" which forms a disjunction.
835830
bool isExplicitConversion() const;

include/swift/Sema/ConstraintSystem.h

Lines changed: 21 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,14 @@ class TypeVariableType::Implementation {
497497
/// literal (represented by `ArrayExpr` and `DictionaryExpr` in AST).
498498
bool isCollectionLiteralType() const;
499499

500+
/// Determine whether this type variable represents a literal such
501+
/// as an integer value, a floating-point value with and without a sign.
502+
bool isNumberLiteralType() const;
503+
504+
/// Determine whether this type variable represents a result type of a
505+
/// function call.
506+
bool isFunctionResult() const;
507+
500508
/// Retrieve the representative of the equivalence class to which this
501509
/// type variable belongs.
502510
///
@@ -2203,10 +2211,6 @@ class ConstraintSystem {
22032211

22042212
llvm::SetVector<TypeVariableType *> TypeVariables;
22052213

2206-
/// Maps expressions to types for choosing a favored overload
2207-
/// type in a disjunction constraint.
2208-
llvm::DenseMap<Expr *, TypeBase *> FavoredTypes;
2209-
22102214
/// Maps discovered closures to their types inferred
22112215
/// from declared parameters/result and body.
22122216
///
@@ -2424,75 +2428,6 @@ class ConstraintSystem {
24242428
SynthesizedConformances;
24252429

24262430
private:
2427-
/// Describe the candidate expression for partial solving.
2428-
/// This class used by shrink & solve methods which apply
2429-
/// variation of directional path consistency algorithm in attempt
2430-
/// to reduce scopes of the overload sets (disjunctions) in the system.
2431-
class Candidate {
2432-
Expr *E;
2433-
DeclContext *DC;
2434-
llvm::BumpPtrAllocator &Allocator;
2435-
2436-
// Contextual Information.
2437-
Type CT;
2438-
ContextualTypePurpose CTP;
2439-
2440-
public:
2441-
Candidate(ConstraintSystem &cs, Expr *expr, Type ct = Type(),
2442-
ContextualTypePurpose ctp = ContextualTypePurpose::CTP_Unused)
2443-
: E(expr), DC(cs.DC), Allocator(cs.Allocator), CT(ct), CTP(ctp) {}
2444-
2445-
/// Return underlying expression.
2446-
Expr *getExpr() const { return E; }
2447-
2448-
/// Try to solve this candidate sub-expression
2449-
/// and re-write it's OSR domains afterwards.
2450-
///
2451-
/// \param shrunkExprs The set of expressions which
2452-
/// domains have been successfully shrunk so far.
2453-
///
2454-
/// \returns true on solver failure, false otherwise.
2455-
bool solve(llvm::SmallSetVector<OverloadSetRefExpr *, 4> &shrunkExprs);
2456-
2457-
/// Apply solutions found by solver as reduced OSR sets for
2458-
/// for current and all of it's sub-expressions.
2459-
///
2460-
/// \param solutions The solutions found by running solver on the
2461-
/// this candidate expression.
2462-
///
2463-
/// \param shrunkExprs The set of expressions which
2464-
/// domains have been successfully shrunk so far.
2465-
void applySolutions(
2466-
llvm::SmallVectorImpl<Solution> &solutions,
2467-
llvm::SmallSetVector<OverloadSetRefExpr *, 4> &shrunkExprs) const;
2468-
2469-
/// Check if attempt at solving of the candidate makes sense given
2470-
/// the current conditions - number of shrunk domains which is related
2471-
/// to the given candidate over the total number of disjunctions present.
2472-
static bool
2473-
isTooComplexGiven(ConstraintSystem *const cs,
2474-
llvm::SmallSetVector<OverloadSetRefExpr *, 4> &shrunkExprs) {
2475-
SmallVector<Constraint *, 8> disjunctions;
2476-
cs->collectDisjunctions(disjunctions);
2477-
2478-
unsigned unsolvedDisjunctions = disjunctions.size();
2479-
for (auto *disjunction : disjunctions) {
2480-
auto *locator = disjunction->getLocator();
2481-
if (!locator)
2482-
continue;
2483-
2484-
if (auto *OSR = getAsExpr<OverloadSetRefExpr>(locator->getAnchor())) {
2485-
if (shrunkExprs.count(OSR) > 0)
2486-
--unsolvedDisjunctions;
2487-
}
2488-
}
2489-
2490-
unsigned threshold =
2491-
cs->getASTContext().TypeCheckerOpts.SolverShrinkUnsolvedThreshold;
2492-
return unsolvedDisjunctions >= threshold;
2493-
}
2494-
};
2495-
24962431
/// Describes the current solver state.
24972432
struct SolverState {
24982433
SolverState(ConstraintSystem &cs,
@@ -3016,15 +2951,6 @@ class ConstraintSystem {
30162951
return nullptr;
30172952
}
30182953

3019-
TypeBase* getFavoredType(Expr *E) {
3020-
assert(E != nullptr);
3021-
return this->FavoredTypes[E];
3022-
}
3023-
void setFavoredType(Expr *E, TypeBase *T) {
3024-
assert(E != nullptr);
3025-
this->FavoredTypes[E] = T;
3026-
}
3027-
30282954
/// Set the type in our type map for the given node, and record the change
30292955
/// in the trail.
30302956
///
@@ -5280,19 +5206,11 @@ class ConstraintSystem {
52805206
/// \returns true if an error occurred, false otherwise.
52815207
bool solveSimplified(SmallVectorImpl<Solution> &solutions);
52825208

5283-
/// Find reduced domains of disjunction constraints for given
5284-
/// expression, this is achieved to solving individual sub-expressions
5285-
/// and combining resolving types. Such algorithm is called directional
5286-
/// path consistency because it goes from children to parents for all
5287-
/// related sub-expressions taking union of their domains.
5288-
///
5289-
/// \param expr The expression to find reductions for.
5290-
void shrink(Expr *expr);
5291-
52925209
/// Pick a disjunction from the InactiveConstraints list.
52935210
///
5294-
/// \returns The selected disjunction.
5295-
Constraint *selectDisjunction();
5211+
/// \returns The selected disjunction and a set of it's favored choices.
5212+
std::optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
5213+
selectDisjunction();
52965214

52975215
/// Pick a conjunction from the InactiveConstraints list.
52985216
///
@@ -5481,11 +5399,6 @@ class ConstraintSystem {
54815399
bool applySolutionToBody(TapExpr *tapExpr,
54825400
SyntacticElementTargetRewriter &rewriter);
54835401

5484-
/// Reorder the disjunctive clauses for a given expression to
5485-
/// increase the likelihood that a favored constraint will be successfully
5486-
/// resolved before any others.
5487-
void optimizeConstraints(Expr *e);
5488-
54895402
void startExpressionTimer(ExpressionTimer::AnchorType anchor);
54905403

54915404
/// Determine if we've already explored too many paths in an
@@ -6226,7 +6139,8 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
62266139
public:
62276140
using Element = DisjunctionChoice;
62286141

6229-
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction)
6142+
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction,
6143+
llvm::TinyPtrVector<Constraint *> &favorites)
62306144
: BindingProducer(cs, disjunction->shouldRememberChoice()
62316145
? disjunction->getLocator()
62326146
: nullptr),
@@ -6236,6 +6150,11 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
62366150
assert(disjunction->getKind() == ConstraintKind::Disjunction);
62376151
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
62386152

6153+
// Mark constraints as favored. This information
6154+
// is going to be used by partitioner.
6155+
for (auto *choice : favorites)
6156+
cs.favorConstraint(choice);
6157+
62396158
// Order and partition the disjunction choices.
62406159
partitionDisjunction(Ordering, PartitionBeginning);
62416160
}
@@ -6280,8 +6199,9 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
62806199
// Partition the choices in the disjunction into groups that we will
62816200
// iterate over in an order appropriate to attempt to stop before we
62826201
// have to visit all of the options.
6283-
void partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6284-
SmallVectorImpl<unsigned> &PartitionBeginning);
6202+
void
6203+
partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6204+
SmallVectorImpl<unsigned> &PartitionBeginning);
62856205

62866206
/// Partition the choices in the range \c first to \c last into groups and
62876207
/// order the groups in the best order to attempt based on the argument

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_swift_host_library(swiftSema STATIC
1313
CSStep.cpp
1414
CSTrail.cpp
1515
CSFix.cpp
16+
CSOptimizer.cpp
1617
CSDiagnostics.cpp
1718
CodeSynthesis.cpp
1819
CodeSynthesisDistributedActor.cpp

lib/Sema/CSBindings.cpp

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ using namespace swift;
3131
using namespace constraints;
3232
using namespace inference;
3333

34+
/// Check whether there exists a type that could be implicitly converted
35+
/// to a given type i.e. is the given type is Double or Optional<..> this
36+
/// function is going to return true because CGFloat could be converted
37+
/// to a Double and non-optional value could be injected into an optional.
38+
static bool hasConversions(Type);
39+
3440
static std::optional<Type> checkTypeOfBinding(TypeVariableType *typeVar,
3541
Type type);
3642

@@ -1209,7 +1215,31 @@ bool BindingSet::isViable(PotentialBinding &binding, bool isTransitive) {
12091215
if (!existingNTD || NTD != existingNTD)
12101216
continue;
12111217

1212-
// FIXME: What is going on here needs to be thoroughly re-evaluated.
1218+
// What is going on in this method needs to be thoroughly re-evaluated!
1219+
//
1220+
// This logic aims to skip dropping bindings if
1221+
// collection type has conversions i.e. in situations like:
1222+
//
1223+
// [$T1] conv $T2
1224+
// $T2 conv [(Int, String)]
1225+
// $T2.Element equal $T5.Element
1226+
//
1227+
// `$T1` could be bound to `(i: Int, v: String)` after
1228+
// `$T2` is bound to `[(Int, String)]` which is is a problem
1229+
// because it means that `$T2` was attempted to early
1230+
// before the solver had a chance to discover all viable
1231+
// bindings.
1232+
//
1233+
// Let's say existing binding is `[(Int, String)]` and
1234+
// relation is "exact", in this case there is no point
1235+
// tracking `[$T1]` because upcasts are only allowed for
1236+
// subtype and other conversions.
1237+
if (existing->Kind != AllowedBindingKind::Exact) {
1238+
if (existingType->isKnownStdlibCollectionType() &&
1239+
hasConversions(existingType)) {
1240+
continue;
1241+
}
1242+
}
12131243

12141244
// If new type has a type variable it shouldn't
12151245
// be considered viable.
@@ -2417,17 +2447,35 @@ bool TypeVarBindingProducer::computeNext() {
24172447
if (binding.Kind == BindingKind::Subtypes || CS.shouldAttemptFixes()) {
24182448
// If we were unsuccessful solving for T?, try solving for T.
24192449
if (auto objTy = type->getOptionalObjectType()) {
2420-
// If T is a type variable, only attempt this if both the
2421-
// type variable we are trying bindings for, and the type
2422-
// variable we will attempt to bind, both have the same
2423-
// polarity with respect to being able to bind lvalues.
2424-
if (auto otherTypeVar = objTy->getAs<TypeVariableType>()) {
2425-
if (TypeVar->getImpl().canBindToLValue() ==
2426-
otherTypeVar->getImpl().canBindToLValue()) {
2427-
addNewBinding(binding.withSameSource(objTy, binding.Kind));
2450+
// TODO: This could be generalized in the future to cover all patterns
2451+
// that have an intermediate type variable in subtype/conversion chain.
2452+
//
2453+
// Let's not perform $T? -> $T for closure result types to avoid having
2454+
// to re-discover solutions that differ only in location of optional
2455+
// injection.
2456+
//
2457+
// The pattern with such type variables is:
2458+
//
2459+
// $T_body <conv/subtype> $T_result <conv/subtype> $T_contextual_result
2460+
//
2461+
// When $T_contextual_result is Optional<$U>, the optional injection
2462+
// can either happen from $T_body or from $T_result (if `return`
2463+
// expression is non-optional), if we allow both the solver would
2464+
// find two solutions that differ only in location of optional
2465+
// injection.
2466+
if (!TypeVar->getImpl().isClosureResultType()) {
2467+
// If T is a type variable, only attempt this if both the
2468+
// type variable we are trying bindings for, and the type
2469+
// variable we will attempt to bind, both have the same
2470+
// polarity with respect to being able to bind lvalues.
2471+
if (auto otherTypeVar = objTy->getAs<TypeVariableType>()) {
2472+
if (TypeVar->getImpl().canBindToLValue() ==
2473+
otherTypeVar->getImpl().canBindToLValue()) {
2474+
addNewBinding(binding.withType(objTy));
2475+
}
2476+
} else {
2477+
addNewBinding(binding.withType(objTy));
24282478
}
2429-
} else {
2430-
addNewBinding(binding.withSameSource(objTy, binding.Kind));
24312479
}
24322480
}
24332481
}

0 commit comments

Comments
 (0)