Skip to content

Commit 4c05a35

Browse files
authored
[CS] Simplify getCalleeDeclAndArgs (#27102)
[CS] Simplify getCalleeDeclAndArgs
2 parents 998d2ed + dafcaeb commit 4c05a35

File tree

8 files changed

+121
-151
lines changed

8 files changed

+121
-151
lines changed

lib/Sema/CSDiag.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6326,21 +6326,17 @@ void FailureDiagnosis::diagnoseAmbiguity(Expr *E) {
63266326
/// If an UnresolvedDotExpr, SubscriptMember, etc has been resolved by the
63276327
/// constraint system, return the decl that it references.
63286328
ValueDecl *ConstraintSystem::findResolvedMemberRef(ConstraintLocator *locator) {
6329-
auto *resolvedOverloadSets = this->getResolvedOverloadSets();
6330-
if (!resolvedOverloadSets) return nullptr;
6331-
63326329
// Search through the resolvedOverloadSets to see if we have a resolution for
63336330
// this member. This is an O(n) search, but only happens when producing an
63346331
// error diagnostic.
6335-
for (auto resolved = resolvedOverloadSets;
6336-
resolved; resolved = resolved->Previous) {
6337-
if (resolved->Locator != locator) continue;
6338-
6339-
// We only handle the simplest decl binding.
6340-
if (resolved->Choice.getKind() != OverloadChoiceKind::Decl)
6341-
return nullptr;
6342-
return resolved->Choice.getDecl();
6343-
}
6344-
6345-
return nullptr;
6332+
auto *overload = findSelectedOverloadFor(locator);
6333+
if (!overload)
6334+
return nullptr;
6335+
6336+
// We only want to handle the simplest decl binding.
6337+
auto choice = overload->Choice;
6338+
if (choice.getKind() != OverloadChoiceKind::Decl)
6339+
return nullptr;
6340+
6341+
return choice.getDecl();
63466342
}

lib/Sema/CSDiagnostics.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,7 @@ class FailureDiagnostic {
144144
/// by the constraint solver.
145145
ResolvedOverloadSetListItem *
146146
getResolvedOverload(ConstraintLocator *locator) const {
147-
auto resolvedOverload = CS.getResolvedOverloadSets();
148-
while (resolvedOverload) {
149-
if (resolvedOverload->Locator == locator)
150-
return resolvedOverload;
151-
resolvedOverload = resolvedOverload->Previous;
152-
}
153-
return nullptr;
147+
return CS.findSelectedOverloadFor(locator);
154148
}
155149

156150
/// Retrive the constraint locator for the given anchor and

lib/Sema/CSGen.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -982,8 +982,8 @@ namespace {
982982
/// Add constraints for a subscript operation.
983983
Type addSubscriptConstraints(
984984
Expr *anchor, Type baseTy, Expr *index,
985-
ValueDecl *declOrNull,
986-
ConstraintLocator *locator = nullptr,
985+
ValueDecl *declOrNull, ArrayRef<Identifier> argLabels,
986+
bool hasTrailingClosure, ConstraintLocator *locator = nullptr,
987987
SmallVectorImpl<TypeVariableType *> *addedTypeVars = nullptr) {
988988
// Locators used in this expression.
989989
if (locator == nullptr)
@@ -999,6 +999,8 @@ namespace {
999999
CS.getConstraintLocator(locator,
10001000
ConstraintLocator::FunctionResult);
10011001

1002+
associateArgumentLabels(memberLocator, {argLabels, hasTrailingClosure});
1003+
10021004
Type outputTy;
10031005

10041006
// For an integer subscript expression on an array slice type, instead of
@@ -1229,8 +1231,9 @@ namespace {
12291231
}
12301232

12311233
Type visitObjectLiteralExpr(ObjectLiteralExpr *expr) {
1234+
auto *exprLoc = CS.getConstraintLocator(expr);
12321235
associateArgumentLabels(
1233-
expr, {expr->getArgumentLabels(), expr->hasTrailingClosure()});
1236+
exprLoc, {expr->getArgumentLabels(), expr->hasTrailingClosure()});
12341237

12351238
// If the expression has already been assigned a type; just use that type.
12361239
if (expr->getType())
@@ -1244,13 +1247,13 @@ namespace {
12441247
return nullptr;
12451248
}
12461249

1247-
auto tv = CS.createTypeVariable(CS.getConstraintLocator(expr),
1250+
auto tv = CS.createTypeVariable(exprLoc,
12481251
TVO_PrefersSubtypeBinding |
12491252
TVO_CanBindToNoEscape);
12501253

12511254
CS.addConstraint(ConstraintKind::LiteralConformsTo, tv,
12521255
protocol->getDeclaredType(),
1253-
CS.getConstraintLocator(expr));
1256+
exprLoc);
12541257

12551258
// The arguments are required to be argument-convertible to the
12561259
// idealized parameter type of the initializer, which generally
@@ -1511,7 +1514,8 @@ namespace {
15111514
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));
15121515

15131516
associateArgumentLabels(
1514-
expr, {expr->getArgumentLabels(), expr->hasTrailingClosure()});
1517+
CS.getConstraintLocator(expr),
1518+
{expr->getArgumentLabels(), expr->hasTrailingClosure()});
15151519
return baseTy;
15161520
}
15171521

@@ -1746,12 +1750,10 @@ namespace {
17461750
return Type();
17471751
}
17481752

1749-
associateArgumentLabels(
1750-
expr, {expr->getArgumentLabels(), expr->hasTrailingClosure()});
1751-
17521753
return addSubscriptConstraints(expr, CS.getType(expr->getBase()),
17531754
expr->getIndex(),
1754-
decl);
1755+
decl, expr->getArgumentLabels(),
1756+
expr->hasTrailingClosure());
17551757
}
17561758

17571759
Type visitArrayExpr(ArrayExpr *expr) {
@@ -1995,10 +1997,10 @@ namespace {
19951997
}
19961998

19971999
Type visitDynamicSubscriptExpr(DynamicSubscriptExpr *expr) {
1998-
associateArgumentLabels(
1999-
expr, {expr->getArgumentLabels(), expr->hasTrailingClosure()});
20002000
return addSubscriptConstraints(expr, CS.getType(expr->getBase()),
2001-
expr->getIndex(), nullptr);
2001+
expr->getIndex(), /*decl*/ nullptr,
2002+
expr->getArgumentLabels(),
2003+
expr->hasTrailingClosure());
20022004
}
20032005

20042006
Type visitTupleElementExpr(TupleElementExpr *expr) {
@@ -2488,7 +2490,8 @@ namespace {
24882490

24892491
SmallVector<Identifier, 4> scratch;
24902492
associateArgumentLabels(
2491-
expr, {expr->getArgumentLabels(scratch), expr->hasTrailingClosure()},
2493+
CS.getConstraintLocator(expr),
2494+
{expr->getArgumentLabels(scratch), expr->hasTrailingClosure()},
24922495
/*labelsArePermanent=*/isa<CallExpr>(expr));
24932496

24942497
if (auto *UDE = dyn_cast<UnresolvedDotExpr>(fnExpr)) {
@@ -3017,7 +3020,10 @@ namespace {
30173020
// re-type-check the constraints during failure diagnosis.
30183021
case KeyPathExpr::Component::Kind::Subscript: {
30193022
base = addSubscriptConstraints(E, base, component.getIndexExpr(),
3020-
/*decl*/ nullptr, memberLocator,
3023+
/*decl*/ nullptr,
3024+
component.getSubscriptLabels(),
3025+
/*hasTrailingClosure*/ false,
3026+
memberLocator,
30213027
&componentTypeVars);
30223028
break;
30233029
}
@@ -3251,15 +3257,13 @@ namespace {
32513257
llvm_unreachable("unhandled operation");
32523258
}
32533259

3254-
void associateArgumentLabels(Expr *expr,
3260+
void associateArgumentLabels(ConstraintLocator *locator,
32553261
ConstraintSystem::ArgumentInfo info,
32563262
bool labelsArePermanent = true) {
3257-
assert(expr);
3263+
assert(locator && locator->getAnchor());
32583264
// Record the labels.
32593265
if (!labelsArePermanent)
32603266
info.Labels = CS.allocateCopy(info.Labels);
3261-
3262-
auto *locator = CS.getConstraintLocator(expr);
32633267
CS.ArgumentInfos[CS.getArgumentInfoLocator(locator)] = info;
32643268
}
32653269
};

lib/Sema/CSSimplify.cpp

Lines changed: 31 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -735,124 +735,54 @@ static std::tuple<ValueDecl *, bool, ArrayRef<Identifier>, bool,
735735
ConstraintLocator *>
736736
getCalleeDeclAndArgs(ConstraintSystem &cs,
737737
ConstraintLocatorBuilder callBuilder) {
738-
ArrayRef<Identifier> argLabels;
739-
bool hasTrailingClosure = false;
740-
ConstraintLocator *targetLocator = nullptr;
738+
auto formUnknownCallee =
739+
[]() -> std::tuple<ValueDecl *, bool, ArrayRef<Identifier>, bool,
740+
ConstraintLocator *> {
741+
return std::make_tuple(/*decl*/ nullptr, /*hasAppliedSelf*/ false,
742+
/*argLabels*/ ArrayRef<Identifier>(),
743+
/*hasTrailingClosure*/ false,
744+
/*calleeLocator*/ nullptr);
745+
};
741746

742747
auto *callLocator = cs.getConstraintLocator(callBuilder);
743748
auto *callExpr = callLocator->getAnchor();
744749

745750
// Break down the call.
746751
if (!callExpr)
747-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
748-
hasTrailingClosure, targetLocator);
752+
return formUnknownCallee();
749753

750-
auto path = callLocator->getPath();
751754
// Our remaining path can only be 'ApplyArgument'.
755+
auto path = callLocator->getPath();
752756
if (!path.empty() &&
753757
!(path.size() <= 2 &&
754758
path.back().getKind() == ConstraintLocator::ApplyArgument))
755-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
756-
hasTrailingClosure, targetLocator);
759+
return formUnknownCallee();
757760

758761
// Dig out the callee information.
759-
if (auto argInfo = cs.getArgumentInfo(callLocator)) {
760-
argLabels = argInfo->Labels;
761-
hasTrailingClosure = argInfo->HasTrailingClosure;
762-
targetLocator = cs.getConstraintLocator(
763-
isa<CallExpr>(callExpr) ? cast<CallExpr>(callExpr)->getDirectCallee()
764-
: callExpr);
765-
} else if (auto keyPath = dyn_cast<KeyPathExpr>(callExpr)) {
766-
if (path.size() != 2)
767-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
768-
hasTrailingClosure, nullptr);
769-
770-
// We must have a KeyPathComponent followed by an ApplyArgument.
771-
auto componentElt = path[0].getAs<LocatorPathElt::KeyPathComponent>();
772-
if (!componentElt || path[1].getKind() != ConstraintLocator::ApplyArgument)
773-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
774-
hasTrailingClosure, nullptr);
775-
776-
auto componentIndex = componentElt->getIndex();
777-
if (componentIndex >= keyPath->getComponents().size())
778-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
779-
hasTrailingClosure, nullptr);
780-
781-
auto &component = keyPath->getComponents()[componentIndex];
782-
switch (component.getKind()) {
783-
case KeyPathExpr::Component::Kind::Subscript:
784-
case KeyPathExpr::Component::Kind::UnresolvedSubscript:
785-
targetLocator = cs.getConstraintLocator(callExpr, path[0]);
786-
argLabels = component.getSubscriptLabels();
787-
hasTrailingClosure = false; // key paths don't support trailing closures
788-
break;
789-
790-
case KeyPathExpr::Component::Kind::Invalid:
791-
case KeyPathExpr::Component::Kind::UnresolvedProperty:
792-
case KeyPathExpr::Component::Kind::Property:
793-
case KeyPathExpr::Component::Kind::OptionalForce:
794-
case KeyPathExpr::Component::Kind::OptionalChain:
795-
case KeyPathExpr::Component::Kind::OptionalWrap:
796-
case KeyPathExpr::Component::Kind::Identity:
797-
case KeyPathExpr::Component::Kind::TupleElement:
798-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
799-
hasTrailingClosure, nullptr);
800-
}
801-
} else {
802-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
803-
hasTrailingClosure, targetLocator);
804-
}
762+
auto argInfo = cs.getArgumentInfo(callLocator);
763+
if (!argInfo)
764+
return formUnknownCallee();
765+
766+
auto argLabels = argInfo->Labels;
767+
auto hasTrailingClosure = argInfo->HasTrailingClosure;
768+
auto calleeLocator = cs.getCalleeLocator(callLocator);
805769

806770
// Find the overload choice corresponding to the callee locator.
807771
// FIXME: This linearly walks the list of resolved overloads, which is
808772
// potentially very expensive.
809-
Optional<OverloadChoice> choice;
810-
ConstraintLocator *calleeLocator = nullptr;
811-
for (auto resolved = cs.getResolvedOverloadSets(); resolved;
812-
resolved = resolved->Previous) {
813-
// FIXME: Workaround null locators.
814-
if (!resolved->Locator) continue;
815-
816-
auto resolvedLocator = resolved->Locator;
817-
SmallVector<LocatorPathElt, 4> resolvedPath(
818-
resolvedLocator->getPath().begin(),
819-
resolvedLocator->getPath().end());
820-
if (!resolvedPath.empty() &&
821-
(resolvedPath.back().getKind() == ConstraintLocator::SubscriptMember ||
822-
resolvedPath.back().getKind() == ConstraintLocator::Member ||
823-
resolvedPath.back().getKind() == ConstraintLocator::UnresolvedMember ||
824-
resolvedPath.back().getKind() ==
825-
ConstraintLocator::ConstructorMember)) {
826-
resolvedPath.pop_back();
827-
resolvedLocator = cs.getConstraintLocator(
828-
resolvedLocator->getAnchor(),
829-
resolvedPath,
830-
resolvedLocator->getSummaryFlags());
831-
}
832-
833-
SourceRange range;
834-
resolvedLocator = simplifyLocator(cs, resolvedLocator, range);
835-
836-
if (resolvedLocator == targetLocator) {
837-
calleeLocator = resolved->Locator;
838-
choice = resolved->Choice;
839-
break;
840-
}
841-
}
842-
843-
// If we didn't find any matching overloads, we're done.
844-
if (!choice)
845-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
846-
hasTrailingClosure, nullptr);
847-
848-
// If there's a declaration, return it.
849-
if (auto *decl = choice->getDeclOrNull()) {
850-
return std::make_tuple(decl, hasAppliedSelf(cs, *choice), argLabels,
851-
hasTrailingClosure, calleeLocator);
852-
}
853-
854-
return std::make_tuple(nullptr, /*hasAppliedSelf=*/false, argLabels,
855-
hasTrailingClosure, calleeLocator);
773+
auto selectedOverload = cs.findSelectedOverloadFor(calleeLocator);
774+
775+
// If we didn't find any matching overloads, we're done. Just return the
776+
// argument info.
777+
if (!selectedOverload)
778+
return std::make_tuple(/*decl*/ nullptr, /*hasAppliedSelf*/ false,
779+
argLabels, hasTrailingClosure,
780+
/*calleeLocator*/ nullptr);
781+
782+
// Return the found declaration, assuming there is one.
783+
auto choice = selectedOverload->Choice;
784+
return std::make_tuple(choice.getDeclOrNull(), hasAppliedSelf(cs, choice),
785+
argLabels, hasTrailingClosure, calleeLocator);
856786
}
857787

858788
class ArgumentFailureTracker : public MatchCallArgumentListener {

lib/Sema/ConstraintSystem.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,17 @@ class ConstraintSystem {
15921592
return resolvedOverloadSets;
15931593
}
15941594

1595+
ResolvedOverloadSetListItem *
1596+
findSelectedOverloadFor(ConstraintLocator *locator) const {
1597+
auto resolvedOverload = getResolvedOverloadSets();
1598+
while (resolvedOverload) {
1599+
if (resolvedOverload->Locator == locator)
1600+
return resolvedOverload;
1601+
resolvedOverload = resolvedOverload->Previous;
1602+
}
1603+
return nullptr;
1604+
}
1605+
15951606
ResolvedOverloadSetListItem *findSelectedOverloadFor(Expr *expr) const {
15961607
auto resolvedOverload = getResolvedOverloadSets();
15971608
while (resolvedOverload) {

test/Constraints/function_builder.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,18 @@ print(mbuilders.methodBuilder(13))
212212
// CHECK: ("propertyBuilder", 12)
213213
print(mbuilders.propertyBuilder)
214214

215+
// SR-11439: Operator builders
216+
infix operator ^^^
217+
func ^^^ (lhs: Int, @TupleBuilder rhs: (Int) -> (String, Int)) -> (String, Int) {
218+
return rhs(lhs)
219+
}
220+
221+
// CHECK: ("hello", 6)
222+
print(5 ^^^ {
223+
"hello"
224+
$0 + 1
225+
})
226+
215227
struct Tagged<Tag, Entity> {
216228
let tag: Tag
217229
let entity: Entity

test/Constraints/function_builder_diags.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,30 @@ func test_51167632() -> some P {
172172
// expected-note@-1 {{explicitly specify the generic arguments to fix this issue}} {{10-10=<<#L: P#>>}}
173173
})
174174
}
175+
176+
struct SR11440 {
177+
typealias ReturnsTuple<T> = () -> (T, T)
178+
subscript<T, U>(@TupleBuilder x: ReturnsTuple<T>) -> (ReturnsTuple<U>) -> Void { //expected-note {{in call to 'subscript(_:)'}}
179+
return { _ in }
180+
}
181+
182+
func foo() {
183+
// This is okay, we apply the function builder for the subscript arg.
184+
self[{
185+
5
186+
5
187+
}]({
188+
(5, 5)
189+
})
190+
191+
// But we shouldn't perform the transform for the argument to the call
192+
// made on the function returned from the subscript.
193+
self[{ // expected-error {{generic parameter 'U' could not be inferred}}
194+
5
195+
5
196+
}]({
197+
5
198+
5
199+
})
200+
}
201+
}

0 commit comments

Comments
 (0)