Skip to content

Commit f75f5fe

Browse files
authored
Merge pull request #36879 from nathawes/track-match-call-result
[ConstraintSystem] Record parameter bindings in solutions (NFC)
2 parents e307650 + c57c403 commit f75f5fe

File tree

8 files changed

+138
-116
lines changed

8 files changed

+138
-116
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,43 @@ class SolutionApplicationTargetsKey {
11111111
}
11121112
};
11131113

1114+
/// Describes the arguments to which a parameter binds.
1115+
/// FIXME: This is an awful data structure. We want the equivalent of a
1116+
/// TinyPtrVector for unsigned values.
1117+
using ParamBinding = SmallVector<unsigned, 1>;
1118+
1119+
/// The result of calling matchCallArguments().
1120+
struct MatchCallArgumentResult {
1121+
/// The direction of trailing closure matching that was performed.
1122+
TrailingClosureMatching trailingClosureMatching;
1123+
1124+
/// The parameter bindings determined by the match.
1125+
SmallVector<ParamBinding, 4> parameterBindings;
1126+
1127+
/// When present, the forward and backward scans each produced a result,
1128+
/// and the parameter bindings are different. The primary result will be
1129+
/// forwarding, and this represents the backward binding.
1130+
Optional<SmallVector<ParamBinding, 4>> backwardParameterBindings;
1131+
1132+
friend bool operator==(const MatchCallArgumentResult &lhs,
1133+
const MatchCallArgumentResult &rhs) {
1134+
if (lhs.trailingClosureMatching != rhs.trailingClosureMatching)
1135+
return false;
1136+
if (lhs.parameterBindings != rhs.parameterBindings)
1137+
return false;
1138+
return lhs.backwardParameterBindings == rhs.backwardParameterBindings;
1139+
}
1140+
1141+
/// Generate a result that maps the provided number of arguments to the same
1142+
/// number of parameters via forward match.
1143+
static MatchCallArgumentResult forArity(unsigned argCount) {
1144+
SmallVector<ParamBinding, 4> Bindings;
1145+
for (unsigned i : range(argCount))
1146+
Bindings.push_back({i});
1147+
return {TrailingClosureMatching::Forward, Bindings, None};
1148+
}
1149+
};
1150+
11141151
/// A complete solution to a constraint system.
11151152
///
11161153
/// A solution to a constraint system consists of type variable bindings to
@@ -1159,9 +1196,9 @@ class Solution {
11591196
llvm::SmallVector<ConstraintFix *, 4> Fixes;
11601197

11611198
/// For locators associated with call expressions, the trailing closure
1162-
/// matching rule that was applied.
1163-
llvm::SmallMapVector<ConstraintLocator*, TrailingClosureMatching, 4>
1164-
trailingClosureMatchingChoices;
1199+
/// matching rule and parameter bindings that were applied.
1200+
llvm::SmallMapVector<ConstraintLocator *, MatchCallArgumentResult, 4>
1201+
argumentMatchingChoices;
11651202

11661203
/// The set of disjunction choices used to arrive at this solution,
11671204
/// which informs constraint application.
@@ -1203,6 +1240,10 @@ class Solution {
12031240
/// A map from argument expressions to their applied property wrapper expressions.
12041241
llvm::MapVector<ASTNode, SmallVector<AppliedPropertyWrapper, 2>> appliedPropertyWrappers;
12051242

1243+
/// Record a new argument matching choice for given locator that maps a
1244+
/// single argument to a single parameter.
1245+
void recordSingleArgMatchingChoice(ConstraintLocator *locator);
1246+
12061247
/// Simplify the given type by substituting all occurrences of
12071248
/// type variables for their fixed types.
12081249
Type simplifyType(Type type) const;
@@ -2210,9 +2251,9 @@ class ConstraintSystem {
22102251
AppliedDisjunctions;
22112252

22122253
/// For locators associated with call expressions, the trailing closure
2213-
/// matching rule that was applied.
2214-
std::vector<std::pair<ConstraintLocator*, TrailingClosureMatching>>
2215-
trailingClosureMatchingChoices;
2254+
/// matching rule and parameter bindings that were applied.
2255+
std::vector<std::pair<ConstraintLocator *, MatchCallArgumentResult>>
2256+
argumentMatchingChoices;
22162257

22172258
/// The set of implicit value conversions performed by the solver on
22182259
/// a current path to reach a solution.
@@ -2709,8 +2750,8 @@ class ConstraintSystem {
27092750
/// The length of \c AppliedDisjunctions.
27102751
unsigned numAppliedDisjunctions;
27112752

2712-
/// The length of \c trailingClosureMatchingChoices;
2713-
unsigned numTrailingClosureMatchingChoices;
2753+
/// The length of \c argumentMatchingChoices.
2754+
unsigned numArgumentMatchingChoices;
27142755

27152756
/// The length of \c OpenedTypes.
27162757
unsigned numOpenedTypes;
@@ -3226,11 +3267,8 @@ class ConstraintSystem {
32263267

32273268
void recordPotentialHole(Type type);
32283269

3229-
void recordTrailingClosureMatch(
3230-
ConstraintLocator *locator,
3231-
TrailingClosureMatching trailingClosureMatch) {
3232-
trailingClosureMatchingChoices.push_back({locator, trailingClosureMatch});
3233-
}
3270+
void recordMatchCallArgumentResult(ConstraintLocator *locator,
3271+
MatchCallArgumentResult result);
32343272

32353273
/// Walk a closure AST to determine its effects.
32363274
///
@@ -5071,11 +5109,6 @@ static inline bool computeTupleShuffle(TupleType *fromTuple,
50715109
sources);
50725110
}
50735111

5074-
/// Describes the arguments to which a parameter binds.
5075-
/// FIXME: This is an awful data structure. We want the equivalent of a
5076-
/// TinyPtrVector for unsigned values.
5077-
using ParamBinding = SmallVector<unsigned, 1>;
5078-
50795112
/// Class used as the base for listeners to the \c matchCallArguments process.
50805113
///
50815114
/// By default, none of the callbacks do anything.
@@ -5143,20 +5176,6 @@ class MatchCallArgumentListener {
51435176
virtual bool relabelArguments(ArrayRef<Identifier> newNames);
51445177
};
51455178

5146-
/// The result of calling matchCallArguments().
5147-
struct MatchCallArgumentResult {
5148-
/// The direction of trailing closure matching that was performed.
5149-
TrailingClosureMatching trailingClosureMatching;
5150-
5151-
/// The parameter bindings determined by the match.
5152-
SmallVector<ParamBinding, 4> parameterBindings;
5153-
5154-
/// When present, the forward and backward scans each produced a result,
5155-
/// and the parameter bindings are different. The primary result will be
5156-
/// forwarding, and this represents the backward binding.
5157-
Optional<SmallVector<ParamBinding, 4>> backwardParameterBindings;
5158-
};
5159-
51605179
/// Match the call arguments (as described by the given argument type) to
51615180
/// the parameters (as described by the given parameter type).
51625181
///

lib/Sema/CSApply.cpp

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3101,16 +3101,16 @@ namespace {
31013101
DeclNameLoc nameLoc, bool implicit,
31023102
ConstraintLocator *ctorLocator,
31033103
SelectedOverload overload) {
3104+
auto locator = cs.getConstraintLocator(expr);
31043105
auto choice = overload.choice;
31053106
assert(choice.getKind() != OverloadChoiceKind::DeclViaDynamic);
31063107
auto *ctor = cast<ConstructorDecl>(choice.getDecl());
31073108

31083109
// If the subexpression is a metatype, build a direct reference to the
31093110
// constructor.
31103111
if (cs.getType(base)->is<AnyMetatypeType>()) {
3111-
return buildMemberRef(
3112-
base, dotLoc, overload, nameLoc, cs.getConstraintLocator(expr),
3113-
ctorLocator, implicit, AccessSemantics::Ordinary);
3112+
return buildMemberRef(base, dotLoc, overload, nameLoc, locator,
3113+
ctorLocator, implicit, AccessSemantics::Ordinary);
31143114
}
31153115

31163116
// The subexpression must be either 'self' or 'super'.
@@ -3160,8 +3160,7 @@ namespace {
31603160
auto *call = new (cs.getASTContext()) DotSyntaxCallExpr(ctorRef, dotLoc,
31613161
base);
31623162

3163-
return finishApply(call, cs.getType(expr), cs.getConstraintLocator(expr),
3164-
ctorLocator);
3163+
return finishApply(call, cs.getType(expr), locator, ctorLocator);
31653164
}
31663165

31673166
/// Give the deprecation warning for referring to a global function
@@ -4843,19 +4842,19 @@ namespace {
48434842
}
48444843

48454844
auto kind = origComponent.getKind();
4846-
auto locator = cs.getConstraintLocator(
4847-
E, LocatorPathElt::KeyPathComponent(i));
4845+
auto componentLocator =
4846+
cs.getConstraintLocator(E, LocatorPathElt::KeyPathComponent(i));
48484847

4849-
// Adjust the locator such that it includes any additional elements to
4850-
// point to the component's callee, e.g a SubscriptMember for a
4851-
// subscript component.
4852-
locator = cs.getCalleeLocator(locator);
4848+
// Get a locator such that it includes any additional elements to point
4849+
// to the component's callee, e.g a SubscriptMember for a subscript
4850+
// component.
4851+
auto calleeLoc = cs.getCalleeLocator(componentLocator);
48534852

48544853
bool isDynamicMember = false;
48554854
// If this is an unresolved link, make sure we resolved it.
48564855
if (kind == KeyPathExpr::Component::Kind::UnresolvedProperty ||
48574856
kind == KeyPathExpr::Component::Kind::UnresolvedSubscript) {
4858-
auto foundDecl = solution.getOverloadChoiceIfAvailable(locator);
4857+
auto foundDecl = solution.getOverloadChoiceIfAvailable(calleeLoc);
48594858
if (!foundDecl) {
48604859
// If we couldn't resolve the component, leave it alone.
48614860
resolvedComponents.push_back(origComponent);
@@ -4878,9 +4877,9 @@ namespace {
48784877

48794878
switch (kind) {
48804879
case KeyPathExpr::Component::Kind::UnresolvedProperty: {
4881-
buildKeyPathPropertyComponent(solution.getOverloadChoice(locator),
4882-
origComponent.getLoc(),
4883-
locator, resolvedComponents);
4880+
buildKeyPathPropertyComponent(solution.getOverloadChoice(calleeLoc),
4881+
origComponent.getLoc(), calleeLoc,
4882+
resolvedComponents);
48844883
break;
48854884
}
48864885
case KeyPathExpr::Component::Kind::UnresolvedSubscript: {
@@ -4889,9 +4888,9 @@ namespace {
48894888
subscriptLabels = origComponent.getSubscriptLabels();
48904889

48914890
buildKeyPathSubscriptComponent(
4892-
solution.getOverloadChoice(locator),
4893-
origComponent.getLoc(), origComponent.getIndexExpr(),
4894-
subscriptLabels, locator, resolvedComponents);
4891+
solution.getOverloadChoice(calleeLoc), origComponent.getLoc(),
4892+
origComponent.getIndexExpr(), subscriptLabels, componentLocator,
4893+
resolvedComponents);
48954894
break;
48964895
}
48974896
case KeyPathExpr::Component::Kind::OptionalChain: {
@@ -5140,9 +5139,10 @@ namespace {
51405139
SmallVectorImpl<KeyPathExpr::Component> &components) {
51415140
auto subscript = cast<SubscriptDecl>(overload.choice.getDecl());
51425141
assert(!subscript->isGetterMutating());
5142+
auto memberLoc = cs.getCalleeLocator(locator);
51435143

51445144
// Compute substitutions to refer to the member.
5145-
auto ref = resolveConcreteDeclRef(subscript, locator);
5145+
auto ref = resolveConcreteDeclRef(subscript, memberLoc);
51465146

51475147
// If this is a @dynamicMemberLookup reference to resolve a property
51485148
// through the subscript(dynamicMember:) member, restore the
@@ -5161,23 +5161,27 @@ namespace {
51615161
if (overload.choice.getKind() ==
51625162
OverloadChoiceKind::KeyPathDynamicMemberLookup) {
51635163
indexExpr = buildKeyPathDynamicMemberIndexExpr(
5164-
indexType->castTo<BoundGenericType>(), componentLoc, locator);
5164+
indexType->castTo<BoundGenericType>(), componentLoc, memberLoc);
51655165
} else {
51665166
auto fieldName = overload.choice.getName().getBaseIdentifier().str();
51675167
indexExpr = buildDynamicMemberLookupIndexExpr(fieldName, componentLoc,
51685168
indexType);
51695169
}
5170+
// Record the implicit subscript expr's parameter bindings and matching
5171+
// direction as `coerceCallArguments` requires them.
5172+
solution.recordSingleArgMatchingChoice(locator);
51705173
}
51715174

51725175
auto subscriptType =
51735176
simplifyType(overload.openedType)->castTo<AnyFunctionType>();
51745177
auto resolvedTy = subscriptType->getResult();
51755178

51765179
// Coerce the indices to the type the subscript expects.
5177-
auto *newIndexExpr =
5178-
coerceCallArguments(indexExpr, subscriptType, ref,
5179-
/*applyExpr*/ nullptr, labels,
5180-
locator, /*appliedPropertyWrappers*/ {});
5180+
auto *newIndexExpr = coerceCallArguments(
5181+
indexExpr, subscriptType, ref,
5182+
/*applyExpr*/ nullptr, labels,
5183+
cs.getConstraintLocator(locator, ConstraintLocator::ApplyArgument),
5184+
/*appliedPropertyWrappers*/ {});
51815185

51825186
// We need to be able to hash the captured index values in order for
51835187
// KeyPath itself to be hashable, so check that all of the subscript
@@ -5763,6 +5767,8 @@ Expr *ExprRewriter::coerceCallArguments(
57635767
ArrayRef<Identifier> argLabels,
57645768
ConstraintLocatorBuilder locator,
57655769
ArrayRef<AppliedPropertyWrapper> appliedPropertyWrappers) {
5770+
assert(locator.last() && locator.last()->is<LocatorPathElt::ApplyArgument>());
5771+
57665772
auto &ctx = getConstraintSystem().getASTContext();
57675773
auto params = funcType->getParams();
57685774
unsigned appliedWrapperIndex = 0;
@@ -5775,11 +5781,6 @@ Expr *ExprRewriter::coerceCallArguments(
57755781
LocatorPathElt::ApplyArgToParam(argIdx, paramIdx, flags));
57765782
};
57775783

5778-
bool matchCanFail =
5779-
llvm::any_of(params, [](const AnyFunctionType::Param &param) {
5780-
return param.getPlainType()->hasUnresolvedType();
5781-
});
5782-
57835784
// Determine whether this application has curried self.
57845785
bool skipCurriedSelf = apply ? hasCurriedSelf(cs, callee, apply) : true;
57855786
// Determine the parameter bindings.
@@ -5814,34 +5815,14 @@ Expr *ExprRewriter::coerceCallArguments(
58145815
// Apply labels to arguments.
58155816
AnyFunctionType::relabelParams(args, argLabels);
58165817

5817-
MatchCallArgumentListener listener;
58185818
auto unlabeledTrailingClosureIndex =
58195819
arg->getUnlabeledTrailingClosureIndexOfPackedArgument();
58205820

5821-
// Determine the trailing closure matching rule that was applied. This
5822-
// is only relevant for explicit calls and subscripts.
5823-
auto trailingClosureMatching = TrailingClosureMatching::Forward;
5824-
{
5825-
SmallVector<LocatorPathElt, 4> path;
5826-
auto anchor = locator.getLocatorParts(path);
5827-
if (!path.empty() && path.back().is<LocatorPathElt::ApplyArgument>() &&
5828-
!anchor.isExpr(ExprKind::UnresolvedDot)) {
5829-
auto locatorPtr = cs.getConstraintLocator(locator);
5830-
assert(solution.trailingClosureMatchingChoices.count(locatorPtr) == 1);
5831-
trailingClosureMatching = solution.trailingClosureMatchingChoices.find(
5832-
locatorPtr)->second;
5833-
}
5834-
}
5835-
5836-
auto callArgumentMatch = constraints::matchCallArguments(
5837-
args, params, paramInfo, unlabeledTrailingClosureIndex,
5838-
/*allowFixes=*/false, listener, trailingClosureMatching);
5839-
5840-
assert((matchCanFail || callArgumentMatch) &&
5841-
"Call arguments did not match up?");
5842-
(void)matchCanFail;
5843-
5844-
auto parameterBindings = std::move(callArgumentMatch->parameterBindings);
5821+
// Determine the parameter bindings that were applied.
5822+
auto *locatorPtr = cs.getConstraintLocator(locator);
5823+
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
5824+
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
5825+
->second.parameterBindings;
58455826

58465827
// We should either have parentheses or a tuple.
58475828
auto *argTuple = dyn_cast<TupleExpr>(arg);
@@ -6851,12 +6832,11 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
68516832
ConstraintLocator::ConstructorMember}));
68526833

68536834
solution.overloadChoices.insert({memberLoc, overload});
6854-
solution.trailingClosureMatchingChoices.insert(
6855-
{cs.getConstraintLocator(callLocator,
6856-
ConstraintLocator::ApplyArgument),
6857-
TrailingClosureMatching::Forward});
68586835
}
68596836

6837+
// Record the implicit call's parameter bindings and match direction.
6838+
solution.recordSingleArgMatchingChoice(callLocator);
6839+
68606840
finishApply(implicitInit, toType, callLocator, callLocator);
68616841
return implicitInit;
68626842
}

lib/Sema/CSGen.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,14 @@ namespace {
14541454
}
14551455

14561456
Type visitUnresolvedDotExpr(UnresolvedDotExpr *expr) {
1457+
// UnresolvedDot applies the base to remove a single curry level from a
1458+
// member reference without using an applicable function constraint so
1459+
// we record the call argument matching here so it can be found later when
1460+
// a solution is applied to the AST.
1461+
CS.recordMatchCallArgumentResult(
1462+
CS.getConstraintLocator(expr, ConstraintLocator::ApplyArgument),
1463+
MatchCallArgumentResult::forArity(1));
1464+
14571465
// If this is Builtin.type_join*, just return any type and move
14581466
// on since we're going to discard this, and creating any type
14591467
// variables for the reference will cause problems.

lib/Sema/CSSimplify.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,9 +1310,9 @@ ConstraintSystem::TypeMatchResult constraints::matchCallArguments(
13101310
}
13111311

13121312
selectedTrailingMatching = callArgumentMatch->trailingClosureMatching;
1313-
// Record the direction of matching used for this call.
1314-
cs.recordTrailingClosureMatch(cs.getConstraintLocator(locator),
1315-
selectedTrailingMatching);
1313+
// Record the matching direction and parameter bindings used for this call.
1314+
cs.recordMatchCallArgumentResult(cs.getConstraintLocator(locator),
1315+
*callArgumentMatch);
13161316

13171317
// If there was a disjunction because both forward and backward were
13181318
// possible, increase the score for forward matches to bias toward the
@@ -9939,10 +9939,10 @@ ConstraintSystem::simplifyApplicableFnConstraint(
99399939
// have an explicit inout argument.
99409940
if (type1.getPointer() == desugar2) {
99419941
if (!isOperator || !hasInOut()) {
9942-
recordTrailingClosureMatch(
9942+
recordMatchCallArgumentResult(
99439943
getConstraintLocator(
99449944
outerLocator.withPathElement(ConstraintLocator::ApplyArgument)),
9945-
TrailingClosureMatching::Forward);
9945+
MatchCallArgumentResult::forArity(func1->getNumParams()));
99469946
return SolutionKind::Solved;
99479947
}
99489948
}
@@ -11087,6 +11087,12 @@ void ConstraintSystem::recordPotentialHole(Type type) {
1108711087
});
1108811088
}
1108911089

11090+
void ConstraintSystem::recordMatchCallArgumentResult(
11091+
ConstraintLocator *locator, MatchCallArgumentResult result) {
11092+
assert(locator->isLastElement<LocatorPathElt::ApplyArgument>());
11093+
argumentMatchingChoices.push_back({locator, result});
11094+
}
11095+
1109011096
ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint(
1109111097
ConstraintFix *fix, Type type1, Type type2, ConstraintKind matchKind,
1109211098
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {

0 commit comments

Comments
 (0)