Skip to content

Commit 80e86dd

Browse files
authored
Merge pull request #70212 from xedin/rdar-119036147
[CSBinding] Infer key path root bindings transitively through context…
2 parents 5767e8f + 9875fcf commit 80e86dd

File tree

3 files changed

+168
-11
lines changed

3 files changed

+168
-11
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ class BindingSet {
388388
BindingSet(const PotentialBindings &info)
389389
: CS(info.CS), TypeVar(info.TypeVar), Info(info) {
390390
for (const auto &binding : info.Bindings)
391-
addBinding(binding);
391+
addBinding(binding, /*isTransitive=*/false);
392392

393393
for (auto *literal : info.Literals)
394394
addLiteralRequirement(literal);
@@ -479,7 +479,12 @@ class BindingSet {
479479
}
480480

481481
/// Check if this binding is viable for inclusion in the set.
482-
bool isViable(PotentialBinding &binding);
482+
///
483+
/// \param binding The binding to validate.
484+
/// \param isTransitive Indicates whether this binding has been
485+
/// acquired through transitive inference and requires extra
486+
/// checking.
487+
bool isViable(PotentialBinding &binding, bool isTransitive);
483488

484489
explicit operator bool() const {
485490
return hasViableBindings() || isDirectHole();
@@ -618,7 +623,13 @@ class BindingSet {
618623
void dump(llvm::raw_ostream &out, unsigned indent) const;
619624

620625
private:
621-
void addBinding(PotentialBinding binding);
626+
/// Add a new binding to the set.
627+
///
628+
/// \param binding The binding to add.
629+
/// \param isTransitive Indicates whether this binding has been
630+
/// acquired through transitive inference and requires validity
631+
/// checking.
632+
void addBinding(PotentialBinding binding, bool isTransitive);
622633

623634
void addLiteralRequirement(Constraint *literal);
624635

lib/Sema/CSBindings.cpp

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ using namespace swift;
2727
using namespace constraints;
2828
using namespace inference;
2929

30+
static llvm::Optional<Type> checkTypeOfBinding(TypeVariableType *typeVar,
31+
Type type);
32+
3033
bool BindingSet::forClosureResult() const {
3134
return Info.TypeVar->getImpl().isClosureResultType();
3235
}
@@ -457,9 +460,37 @@ void BindingSet::inferTransitiveBindings(
457460
inferredRootTy = fnType->getParams()[0].getParameterType();
458461
}
459462

460-
if (inferredRootTy && !inferredRootTy->isTypeVariableOrMember())
461-
addBinding(
462-
binding.withSameSource(inferredRootTy, BindingKind::Exact));
463+
if (inferredRootTy) {
464+
// If contextual root is not yet resolved, let's try to see if
465+
// there are any bindings in its set. The bindings could be
466+
// transitively used because conversions between generic arguments
467+
// are not allowed.
468+
if (auto *contextualRootVar = inferredRootTy->getAs<TypeVariableType>()) {
469+
auto rootBindings = inferredBindings.find(contextualRootVar);
470+
if (rootBindings != inferredBindings.end()) {
471+
auto &bindings = rootBindings->getSecond();
472+
473+
// Don't infer if root is not yet fully resolved.
474+
if (bindings.isDelayed())
475+
continue;
476+
477+
// Copy the bindings over to the root.
478+
for (const auto &binding : bindings.Bindings)
479+
addBinding(binding, /*isTransitive=*/true);
480+
481+
// Make a note that the key path root is transitively adjacent
482+
// to contextual root type variable and all of its variables.
483+
// This is important for ranking.
484+
AdjacentVars.insert(contextualRootVar);
485+
AdjacentVars.insert(bindings.AdjacentVars.begin(),
486+
bindings.AdjacentVars.end());
487+
}
488+
} else {
489+
addBinding(
490+
binding.withSameSource(inferredRootTy, BindingKind::Exact),
491+
/*isTransitive=*/true);
492+
}
493+
}
463494
}
464495
}
465496
}
@@ -526,7 +557,8 @@ void BindingSet::inferTransitiveBindings(
526557
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
527558
continue;
528559

529-
addBinding(binding.withSameSource(type, BindingKind::Supertypes));
560+
addBinding(binding.withSameSource(type, BindingKind::Supertypes),
561+
/*isTransitive=*/true);
530562
}
531563
}
532564
}
@@ -604,7 +636,8 @@ void BindingSet::finalize(
604636
continue;
605637
}
606638

607-
addBinding({protocolTy, AllowedBindingKind::Exact, constraint});
639+
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
640+
/*isTransitive=*/false);
608641
}
609642
}
610643
}
@@ -713,11 +746,11 @@ void BindingSet::finalize(
713746
}
714747
}
715748

716-
void BindingSet::addBinding(PotentialBinding binding) {
749+
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
717750
if (Bindings.count(binding))
718751
return;
719752

720-
if (!isViable(binding))
753+
if (!isViable(binding, isTransitive))
721754
return;
722755

723756
SmallPtrSet<TypeVariableType *, 4> referencedTypeVars;
@@ -1138,14 +1171,17 @@ void PotentialBindings::addLiteral(Constraint *constraint) {
11381171
Literals.insert(constraint);
11391172
}
11401173

1141-
bool BindingSet::isViable(PotentialBinding &binding) {
1174+
bool BindingSet::isViable(PotentialBinding &binding, bool isTransitive) {
11421175
// Prevent against checking against the same opened nominal type
11431176
// over and over again. Doing so means redundant work in the best
11441177
// case. In the worst case, we'll produce lots of duplicate solutions
11451178
// for this constraint system, which is problematic for overload
11461179
// resolution.
11471180
auto type = binding.BindingType;
11481181

1182+
if (isTransitive && !checkTypeOfBinding(TypeVar, type))
1183+
return false;
1184+
11491185
auto *NTD = type->getAnyNominal();
11501186
if (!NTD)
11511187
return true;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: %target-typecheck-verify-swift -target %target-cpu-apple-macosx10.15 -solver-expression-time-threshold=1 -swift-version 5
2+
3+
// REQUIRES: OS=macosx
4+
5+
import Combine
6+
7+
enum Status {
8+
case up
9+
case down
10+
}
11+
12+
protocol StatusMonitor {
13+
var statusPublisher: AnyPublisher<Status, Never> { get }
14+
}
15+
16+
protocol UIController {}
17+
protocol ControllerProtocol {}
18+
19+
class TestViewController : UIController, ControllerProtocol {
20+
}
21+
22+
class OtherController {
23+
var innerController: (any UIController & ControllerProtocol)? = nil
24+
}
25+
26+
class Test1 {
27+
var monitor: StatusMonitor
28+
29+
var subscriptions: [AnyCancellable] = []
30+
var status: Status? = nil
31+
var statuses: [Status]? = nil
32+
33+
init(monitor: StatusMonitor) {
34+
self.monitor = monitor
35+
}
36+
37+
func simpleMapTest() {
38+
monitor.statusPublisher
39+
.map { $0 }
40+
.assign(to: \.status, on: self) // Ok
41+
.store(in: &subscriptions)
42+
}
43+
44+
func transformationTest() {
45+
monitor.statusPublisher
46+
.map { _ in (0...1).map { _ in .up } }
47+
.assign(to: \.statuses, on: self) // Ok
48+
.store(in: &subscriptions)
49+
}
50+
}
51+
52+
class FilteringTest {
53+
@Published var flag = false
54+
55+
func test(viewController: inout OtherController) {
56+
_ = $flag.filter { !$0 }
57+
.map { _ in TestViewController() }
58+
.first()
59+
.handleEvents(receiveOutput: { _ in
60+
print("event")
61+
})
62+
.assign(to: \.innerController, on: viewController) // Ok
63+
}
64+
}
65+
66+
extension Sequence {
67+
func sorted<T: Comparable>(by keyPath: KeyPath<Element, T>) -> [Element] {
68+
[]
69+
}
70+
}
71+
72+
func testCollectionUpcastWithTupleLabelErasure() {
73+
struct Item {}
74+
75+
enum Info : Int, Hashable {
76+
case one = 1
77+
}
78+
79+
80+
func test(data: [Info: [Item]]) -> [(Info, [Item])] {
81+
data.map { $0 }
82+
.sorted(by: \.key.rawValue) // Ok
83+
}
84+
}
85+
86+
do {
87+
struct URL {
88+
var path: String
89+
func appendingPathComponent(_: String) -> URL { fatalError() }
90+
}
91+
92+
struct EntryPoint {
93+
var directory: URL { fatalError() }
94+
}
95+
96+
func test(entryPoint: EntryPoint, data: [[String]]) {
97+
let _ = data.map { suffixes in
98+
let elements = ["a", "b"]
99+
.flatMap { dir in
100+
let directory = entryPoint.directory.appendingPathComponent(dir)
101+
return suffixes.map { suffix in
102+
directory.appendingPathComponent("\(suffix)")
103+
}
104+
}
105+
.map(\.path) // Ok
106+
107+
return elements.joined(separator: ",")
108+
}
109+
}
110+
}

0 commit comments

Comments
 (0)