Skip to content

Commit c57c403

Browse files
committed
[ConstraintSystem] Record parameter bindings in solutions (NFC)
This saves us from needing to re-match args to params in CSApply and is also useful for a forthcoming change migrating code completion in argument position to use the solver-based typeCheckForCodeCompletion api. rdar://76581093
1 parent 59a218a commit c57c403

File tree

8 files changed

+138
-116
lines changed

8 files changed

+138
-116
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,43 @@ class SolutionApplicationTargetsKey {
11111111
}
11121112
};
11131113

1114+
/// Describes the arguments to which a parameter binds.
1115+
/// FIXME: This is an awful data structure. We want the equivalent of a
1116+
/// TinyPtrVector for unsigned values.
1117+
using ParamBinding = SmallVector<unsigned, 1>;
1118+
1119+
/// The result of calling matchCallArguments().
1120+
struct MatchCallArgumentResult {
1121+
/// The direction of trailing closure matching that was performed.
1122+
TrailingClosureMatching trailingClosureMatching;
1123+
1124+
/// The parameter bindings determined by the match.
1125+
SmallVector<ParamBinding, 4> parameterBindings;
1126+
1127+
/// When present, the forward and backward scans each produced a result,
1128+
/// and the parameter bindings are different. The primary result will be
1129+
/// forwarding, and this represents the backward binding.
1130+
Optional<SmallVector<ParamBinding, 4>> backwardParameterBindings;
1131+
1132+
friend bool operator==(const MatchCallArgumentResult &lhs,
1133+
const MatchCallArgumentResult &rhs) {
1134+
if (lhs.trailingClosureMatching != rhs.trailingClosureMatching)
1135+
return false;
1136+
if (lhs.parameterBindings != rhs.parameterBindings)
1137+
return false;
1138+
return lhs.backwardParameterBindings == rhs.backwardParameterBindings;
1139+
}
1140+
1141+
/// Generate a result that maps the provided number of arguments to the same
1142+
/// number of parameters via forward match.
1143+
static MatchCallArgumentResult forArity(unsigned argCount) {
1144+
SmallVector<ParamBinding, 4> Bindings;
1145+
for (unsigned i : range(argCount))
1146+
Bindings.push_back({i});
1147+
return {TrailingClosureMatching::Forward, Bindings, None};
1148+
}
1149+
};
1150+
11141151
/// A complete solution to a constraint system.
11151152
///
11161153
/// A solution to a constraint system consists of type variable bindings to
@@ -1159,9 +1196,9 @@ class Solution {
11591196
llvm::SmallVector<ConstraintFix *, 4> Fixes;
11601197

11611198
/// For locators associated with call expressions, the trailing closure
1162-
/// matching rule that was applied.
1163-
llvm::SmallMapVector<ConstraintLocator*, TrailingClosureMatching, 4>
1164-
trailingClosureMatchingChoices;
1199+
/// matching rule and parameter bindings that were applied.
1200+
llvm::SmallMapVector<ConstraintLocator *, MatchCallArgumentResult, 4>
1201+
argumentMatchingChoices;
11651202

11661203
/// The set of disjunction choices used to arrive at this solution,
11671204
/// which informs constraint application.
@@ -1203,6 +1240,10 @@ class Solution {
12031240
/// A map from argument expressions to their applied property wrapper expressions.
12041241
llvm::MapVector<ASTNode, SmallVector<AppliedPropertyWrapper, 2>> appliedPropertyWrappers;
12051242

1243+
/// Record a new argument matching choice for given locator that maps a
1244+
/// single argument to a single parameter.
1245+
void recordSingleArgMatchingChoice(ConstraintLocator *locator);
1246+
12061247
/// Simplify the given type by substituting all occurrences of
12071248
/// type variables for their fixed types.
12081249
Type simplifyType(Type type) const;
@@ -2210,9 +2251,9 @@ class ConstraintSystem {
22102251
AppliedDisjunctions;
22112252

22122253
/// For locators associated with call expressions, the trailing closure
2213-
/// matching rule that was applied.
2214-
std::vector<std::pair<ConstraintLocator*, TrailingClosureMatching>>
2215-
trailingClosureMatchingChoices;
2254+
/// matching rule and parameter bindings that were applied.
2255+
std::vector<std::pair<ConstraintLocator *, MatchCallArgumentResult>>
2256+
argumentMatchingChoices;
22162257

22172258
/// The set of implicit value conversions performed by the solver on
22182259
/// a current path to reach a solution.
@@ -2709,8 +2750,8 @@ class ConstraintSystem {
27092750
/// The length of \c AppliedDisjunctions.
27102751
unsigned numAppliedDisjunctions;
27112752

2712-
/// The length of \c trailingClosureMatchingChoices;
2713-
unsigned numTrailingClosureMatchingChoices;
2753+
/// The length of \c argumentMatchingChoices.
2754+
unsigned numArgumentMatchingChoices;
27142755

27152756
/// The length of \c OpenedTypes.
27162757
unsigned numOpenedTypes;
@@ -3226,11 +3267,8 @@ class ConstraintSystem {
32263267

32273268
void recordPotentialHole(Type type);
32283269

3229-
void recordTrailingClosureMatch(
3230-
ConstraintLocator *locator,
3231-
TrailingClosureMatching trailingClosureMatch) {
3232-
trailingClosureMatchingChoices.push_back({locator, trailingClosureMatch});
3233-
}
3270+
void recordMatchCallArgumentResult(ConstraintLocator *locator,
3271+
MatchCallArgumentResult result);
32343272

32353273
/// Walk a closure AST to determine its effects.
32363274
///
@@ -5086,11 +5124,6 @@ static inline bool computeTupleShuffle(TupleType *fromTuple,
50865124
sources);
50875125
}
50885126

5089-
/// Describes the arguments to which a parameter binds.
5090-
/// FIXME: This is an awful data structure. We want the equivalent of a
5091-
/// TinyPtrVector for unsigned values.
5092-
using ParamBinding = SmallVector<unsigned, 1>;
5093-
50945127
/// Class used as the base for listeners to the \c matchCallArguments process.
50955128
///
50965129
/// By default, none of the callbacks do anything.
@@ -5158,20 +5191,6 @@ class MatchCallArgumentListener {
51585191
virtual bool relabelArguments(ArrayRef<Identifier> newNames);
51595192
};
51605193

5161-
/// The result of calling matchCallArguments().
5162-
struct MatchCallArgumentResult {
5163-
/// The direction of trailing closure matching that was performed.
5164-
TrailingClosureMatching trailingClosureMatching;
5165-
5166-
/// The parameter bindings determined by the match.
5167-
SmallVector<ParamBinding, 4> parameterBindings;
5168-
5169-
/// When present, the forward and backward scans each produced a result,
5170-
/// and the parameter bindings are different. The primary result will be
5171-
/// forwarding, and this represents the backward binding.
5172-
Optional<SmallVector<ParamBinding, 4>> backwardParameterBindings;
5173-
};
5174-
51755194
/// Match the call arguments (as described by the given argument type) to
51765195
/// the parameters (as described by the given parameter type).
51775196
///

lib/Sema/CSApply.cpp

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,16 +3099,16 @@ namespace {
30993099
DeclNameLoc nameLoc, bool implicit,
31003100
ConstraintLocator *ctorLocator,
31013101
SelectedOverload overload) {
3102+
auto locator = cs.getConstraintLocator(expr);
31023103
auto choice = overload.choice;
31033104
assert(choice.getKind() != OverloadChoiceKind::DeclViaDynamic);
31043105
auto *ctor = cast<ConstructorDecl>(choice.getDecl());
31053106

31063107
// If the subexpression is a metatype, build a direct reference to the
31073108
// constructor.
31083109
if (cs.getType(base)->is<AnyMetatypeType>()) {
3109-
return buildMemberRef(
3110-
base, dotLoc, overload, nameLoc, cs.getConstraintLocator(expr),
3111-
ctorLocator, implicit, AccessSemantics::Ordinary);
3110+
return buildMemberRef(base, dotLoc, overload, nameLoc, locator,
3111+
ctorLocator, implicit, AccessSemantics::Ordinary);
31123112
}
31133113

31143114
// The subexpression must be either 'self' or 'super'.
@@ -3158,8 +3158,7 @@ namespace {
31583158
auto *call = new (cs.getASTContext()) DotSyntaxCallExpr(ctorRef, dotLoc,
31593159
base);
31603160

3161-
return finishApply(call, cs.getType(expr), cs.getConstraintLocator(expr),
3162-
ctorLocator);
3161+
return finishApply(call, cs.getType(expr), locator, ctorLocator);
31633162
}
31643163

31653164
/// Give the deprecation warning for referring to a global function
@@ -4841,19 +4840,19 @@ namespace {
48414840
}
48424841

48434842
auto kind = origComponent.getKind();
4844-
auto locator = cs.getConstraintLocator(
4845-
E, LocatorPathElt::KeyPathComponent(i));
4843+
auto componentLocator =
4844+
cs.getConstraintLocator(E, LocatorPathElt::KeyPathComponent(i));
48464845

4847-
// Adjust the locator such that it includes any additional elements to
4848-
// point to the component's callee, e.g a SubscriptMember for a
4849-
// subscript component.
4850-
locator = cs.getCalleeLocator(locator);
4846+
// Get a locator such that it includes any additional elements to point
4847+
// to the component's callee, e.g a SubscriptMember for a subscript
4848+
// component.
4849+
auto calleeLoc = cs.getCalleeLocator(componentLocator);
48514850

48524851
bool isDynamicMember = false;
48534852
// If this is an unresolved link, make sure we resolved it.
48544853
if (kind == KeyPathExpr::Component::Kind::UnresolvedProperty ||
48554854
kind == KeyPathExpr::Component::Kind::UnresolvedSubscript) {
4856-
auto foundDecl = solution.getOverloadChoiceIfAvailable(locator);
4855+
auto foundDecl = solution.getOverloadChoiceIfAvailable(calleeLoc);
48574856
if (!foundDecl) {
48584857
// If we couldn't resolve the component, leave it alone.
48594858
resolvedComponents.push_back(origComponent);
@@ -4876,9 +4875,9 @@ namespace {
48764875

48774876
switch (kind) {
48784877
case KeyPathExpr::Component::Kind::UnresolvedProperty: {
4879-
buildKeyPathPropertyComponent(solution.getOverloadChoice(locator),
4880-
origComponent.getLoc(),
4881-
locator, resolvedComponents);
4878+
buildKeyPathPropertyComponent(solution.getOverloadChoice(calleeLoc),
4879+
origComponent.getLoc(), calleeLoc,
4880+
resolvedComponents);
48824881
break;
48834882
}
48844883
case KeyPathExpr::Component::Kind::UnresolvedSubscript: {
@@ -4887,9 +4886,9 @@ namespace {
48874886
subscriptLabels = origComponent.getSubscriptLabels();
48884887

48894888
buildKeyPathSubscriptComponent(
4890-
solution.getOverloadChoice(locator),
4891-
origComponent.getLoc(), origComponent.getIndexExpr(),
4892-
subscriptLabels, locator, resolvedComponents);
4889+
solution.getOverloadChoice(calleeLoc), origComponent.getLoc(),
4890+
origComponent.getIndexExpr(), subscriptLabels, componentLocator,
4891+
resolvedComponents);
48934892
break;
48944893
}
48954894
case KeyPathExpr::Component::Kind::OptionalChain: {
@@ -5138,9 +5137,10 @@ namespace {
51385137
SmallVectorImpl<KeyPathExpr::Component> &components) {
51395138
auto subscript = cast<SubscriptDecl>(overload.choice.getDecl());
51405139
assert(!subscript->isGetterMutating());
5140+
auto memberLoc = cs.getCalleeLocator(locator);
51415141

51425142
// Compute substitutions to refer to the member.
5143-
auto ref = resolveConcreteDeclRef(subscript, locator);
5143+
auto ref = resolveConcreteDeclRef(subscript, memberLoc);
51445144

51455145
// If this is a @dynamicMemberLookup reference to resolve a property
51465146
// through the subscript(dynamicMember:) member, restore the
@@ -5159,23 +5159,27 @@ namespace {
51595159
if (overload.choice.getKind() ==
51605160
OverloadChoiceKind::KeyPathDynamicMemberLookup) {
51615161
indexExpr = buildKeyPathDynamicMemberIndexExpr(
5162-
indexType->castTo<BoundGenericType>(), componentLoc, locator);
5162+
indexType->castTo<BoundGenericType>(), componentLoc, memberLoc);
51635163
} else {
51645164
auto fieldName = overload.choice.getName().getBaseIdentifier().str();
51655165
indexExpr = buildDynamicMemberLookupIndexExpr(fieldName, componentLoc,
51665166
indexType);
51675167
}
5168+
// Record the implicit subscript expr's parameter bindings and matching
5169+
// direction as `coerceCallArguments` requires them.
5170+
solution.recordSingleArgMatchingChoice(locator);
51685171
}
51695172

51705173
auto subscriptType =
51715174
simplifyType(overload.openedType)->castTo<AnyFunctionType>();
51725175
auto resolvedTy = subscriptType->getResult();
51735176

51745177
// Coerce the indices to the type the subscript expects.
5175-
auto *newIndexExpr =
5176-
coerceCallArguments(indexExpr, subscriptType, ref,
5177-
/*applyExpr*/ nullptr, labels,
5178-
locator, /*appliedPropertyWrappers*/ {});
5178+
auto *newIndexExpr = coerceCallArguments(
5179+
indexExpr, subscriptType, ref,
5180+
/*applyExpr*/ nullptr, labels,
5181+
cs.getConstraintLocator(locator, ConstraintLocator::ApplyArgument),
5182+
/*appliedPropertyWrappers*/ {});
51795183

51805184
// We need to be able to hash the captured index values in order for
51815185
// KeyPath itself to be hashable, so check that all of the subscript
@@ -5761,6 +5765,8 @@ Expr *ExprRewriter::coerceCallArguments(
57615765
ArrayRef<Identifier> argLabels,
57625766
ConstraintLocatorBuilder locator,
57635767
ArrayRef<AppliedPropertyWrapper> appliedPropertyWrappers) {
5768+
assert(locator.last() && locator.last()->is<LocatorPathElt::ApplyArgument>());
5769+
57645770
auto &ctx = getConstraintSystem().getASTContext();
57655771
auto params = funcType->getParams();
57665772
unsigned appliedWrapperIndex = 0;
@@ -5773,11 +5779,6 @@ Expr *ExprRewriter::coerceCallArguments(
57735779
LocatorPathElt::ApplyArgToParam(argIdx, paramIdx, flags));
57745780
};
57755781

5776-
bool matchCanFail =
5777-
llvm::any_of(params, [](const AnyFunctionType::Param &param) {
5778-
return param.getPlainType()->hasUnresolvedType();
5779-
});
5780-
57815782
// Determine whether this application has curried self.
57825783
bool skipCurriedSelf = apply ? hasCurriedSelf(cs, callee, apply) : true;
57835784
// Determine the parameter bindings.
@@ -5812,34 +5813,14 @@ Expr *ExprRewriter::coerceCallArguments(
58125813
// Apply labels to arguments.
58135814
AnyFunctionType::relabelParams(args, argLabels);
58145815

5815-
MatchCallArgumentListener listener;
58165816
auto unlabeledTrailingClosureIndex =
58175817
arg->getUnlabeledTrailingClosureIndexOfPackedArgument();
58185818

5819-
// Determine the trailing closure matching rule that was applied. This
5820-
// is only relevant for explicit calls and subscripts.
5821-
auto trailingClosureMatching = TrailingClosureMatching::Forward;
5822-
{
5823-
SmallVector<LocatorPathElt, 4> path;
5824-
auto anchor = locator.getLocatorParts(path);
5825-
if (!path.empty() && path.back().is<LocatorPathElt::ApplyArgument>() &&
5826-
!anchor.isExpr(ExprKind::UnresolvedDot)) {
5827-
auto locatorPtr = cs.getConstraintLocator(locator);
5828-
assert(solution.trailingClosureMatchingChoices.count(locatorPtr) == 1);
5829-
trailingClosureMatching = solution.trailingClosureMatchingChoices.find(
5830-
locatorPtr)->second;
5831-
}
5832-
}
5833-
5834-
auto callArgumentMatch = constraints::matchCallArguments(
5835-
args, params, paramInfo, unlabeledTrailingClosureIndex,
5836-
/*allowFixes=*/false, listener, trailingClosureMatching);
5837-
5838-
assert((matchCanFail || callArgumentMatch) &&
5839-
"Call arguments did not match up?");
5840-
(void)matchCanFail;
5841-
5842-
auto parameterBindings = std::move(callArgumentMatch->parameterBindings);
5819+
// Determine the parameter bindings that were applied.
5820+
auto *locatorPtr = cs.getConstraintLocator(locator);
5821+
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
5822+
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
5823+
->second.parameterBindings;
58435824

58445825
// We should either have parentheses or a tuple.
58455826
auto *argTuple = dyn_cast<TupleExpr>(arg);
@@ -6849,12 +6830,11 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
68496830
ConstraintLocator::ConstructorMember}));
68506831

68516832
solution.overloadChoices.insert({memberLoc, overload});
6852-
solution.trailingClosureMatchingChoices.insert(
6853-
{cs.getConstraintLocator(callLocator,
6854-
ConstraintLocator::ApplyArgument),
6855-
TrailingClosureMatching::Forward});
68566833
}
68576834

6835+
// Record the implicit call's parameter bindings and match direction.
6836+
solution.recordSingleArgMatchingChoice(callLocator);
6837+
68586838
finishApply(implicitInit, toType, callLocator, callLocator);
68596839
return implicitInit;
68606840
}

lib/Sema/CSGen.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,6 +1477,14 @@ namespace {
14771477
}
14781478

14791479
Type visitUnresolvedDotExpr(UnresolvedDotExpr *expr) {
1480+
// UnresolvedDot applies the base to remove a single curry level from a
1481+
// member reference without using an applicable function constraint so
1482+
// we record the call argument matching here so it can be found later when
1483+
// a solution is applied to the AST.
1484+
CS.recordMatchCallArgumentResult(
1485+
CS.getConstraintLocator(expr, ConstraintLocator::ApplyArgument),
1486+
MatchCallArgumentResult::forArity(1));
1487+
14801488
// If this is Builtin.type_join*, just return any type and move
14811489
// on since we're going to discard this, and creating any type
14821490
// variables for the reference will cause problems.

lib/Sema/CSSimplify.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,9 +1309,9 @@ ConstraintSystem::TypeMatchResult constraints::matchCallArguments(
13091309
}
13101310

13111311
selectedTrailingMatching = callArgumentMatch->trailingClosureMatching;
1312-
// Record the direction of matching used for this call.
1313-
cs.recordTrailingClosureMatch(cs.getConstraintLocator(locator),
1314-
selectedTrailingMatching);
1312+
// Record the matching direction and parameter bindings used for this call.
1313+
cs.recordMatchCallArgumentResult(cs.getConstraintLocator(locator),
1314+
*callArgumentMatch);
13151315

13161316
// If there was a disjunction because both forward and backward were
13171317
// possible, increase the score for forward matches to bias toward the
@@ -9716,10 +9716,10 @@ ConstraintSystem::simplifyApplicableFnConstraint(
97169716
// have an explicit inout argument.
97179717
if (type1.getPointer() == desugar2) {
97189718
if (!isOperator || !hasInOut()) {
9719-
recordTrailingClosureMatch(
9719+
recordMatchCallArgumentResult(
97209720
getConstraintLocator(
97219721
outerLocator.withPathElement(ConstraintLocator::ApplyArgument)),
9722-
TrailingClosureMatching::Forward);
9722+
MatchCallArgumentResult::forArity(func1->getNumParams()));
97239723
return SolutionKind::Solved;
97249724
}
97259725
}
@@ -10864,6 +10864,12 @@ void ConstraintSystem::recordPotentialHole(Type type) {
1086410864
});
1086510865
}
1086610866

10867+
void ConstraintSystem::recordMatchCallArgumentResult(
10868+
ConstraintLocator *locator, MatchCallArgumentResult result) {
10869+
assert(locator->isLastElement<LocatorPathElt::ApplyArgument>());
10870+
argumentMatchingChoices.push_back({locator, result});
10871+
}
10872+
1086710873
ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint(
1086810874
ConstraintFix *fix, Type type1, Type type2, ConstraintKind matchKind,
1086910875
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {

0 commit comments

Comments
 (0)