Skip to content

[CSBinding] Infer key path root bindings transitively through context… #70212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions include/swift/Sema/CSBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class BindingSet {
BindingSet(const PotentialBindings &info)
: CS(info.CS), TypeVar(info.TypeVar), Info(info) {
for (const auto &binding : info.Bindings)
addBinding(binding);
addBinding(binding, /*isTransitive=*/false);

for (auto *literal : info.Literals)
addLiteralRequirement(literal);
Expand Down Expand Up @@ -479,7 +479,12 @@ class BindingSet {
}

/// Check if this binding is viable for inclusion in the set.
bool isViable(PotentialBinding &binding);
///
/// \param binding The binding to validate.
/// \param isTransitive Indicates whether this binding has been
/// acquired through transitive inference and requires extra
/// checking.
bool isViable(PotentialBinding &binding, bool isTransitive);

explicit operator bool() const {
return hasViableBindings() || isDirectHole();
Expand Down Expand Up @@ -618,7 +623,13 @@ class BindingSet {
void dump(llvm::raw_ostream &out, unsigned indent) const;

private:
void addBinding(PotentialBinding binding);
/// Add a new binding to the set.
///
/// \param binding The binding to add.
/// \param isTransitive Indicates whether this binding has been
/// acquired through transitive inference and requires validity
/// checking.
void addBinding(PotentialBinding binding, bool isTransitive);

void addLiteralRequirement(Constraint *literal);

Expand Down
52 changes: 44 additions & 8 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ using namespace swift;
using namespace constraints;
using namespace inference;

static llvm::Optional<Type> checkTypeOfBinding(TypeVariableType *typeVar,
Type type);

bool BindingSet::forClosureResult() const {
return Info.TypeVar->getImpl().isClosureResultType();
}
Expand Down Expand Up @@ -457,9 +460,37 @@ void BindingSet::inferTransitiveBindings(
inferredRootTy = fnType->getParams()[0].getParameterType();
}

if (inferredRootTy && !inferredRootTy->isTypeVariableOrMember())
addBinding(
binding.withSameSource(inferredRootTy, BindingKind::Exact));
if (inferredRootTy) {
// If contextual root is not yet resolved, let's try to see if
// there are any bindings in its set. The bindings could be
// transitively used because conversions between generic arguments
// are not allowed.
if (auto *contextualRootVar = inferredRootTy->getAs<TypeVariableType>()) {
auto rootBindings = inferredBindings.find(contextualRootVar);
if (rootBindings != inferredBindings.end()) {
auto &bindings = rootBindings->getSecond();

// Don't infer if root is not yet fully resolved.
if (bindings.isDelayed())
continue;

// Copy the bindings over to the root.
for (const auto &binding : bindings.Bindings)
addBinding(binding, /*isTransitive=*/true);

// Make a note that the key path root is transitively adjacent
// to contextual root type variable and all of its variables.
// This is important for ranking.
AdjacentVars.insert(contextualRootVar);
AdjacentVars.insert(bindings.AdjacentVars.begin(),
bindings.AdjacentVars.end());
}
} else {
addBinding(
binding.withSameSource(inferredRootTy, BindingKind::Exact),
/*isTransitive=*/true);
}
}
}
}
}
Expand Down Expand Up @@ -526,7 +557,8 @@ void BindingSet::inferTransitiveBindings(
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
continue;

addBinding(binding.withSameSource(type, BindingKind::Supertypes));
addBinding(binding.withSameSource(type, BindingKind::Supertypes),
/*isTransitive=*/true);
}
}
}
Expand Down Expand Up @@ -604,7 +636,8 @@ void BindingSet::finalize(
continue;
}

addBinding({protocolTy, AllowedBindingKind::Exact, constraint});
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
}
Expand Down Expand Up @@ -713,11 +746,11 @@ void BindingSet::finalize(
}
}

void BindingSet::addBinding(PotentialBinding binding) {
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
if (Bindings.count(binding))
return;

if (!isViable(binding))
if (!isViable(binding, isTransitive))
return;

SmallPtrSet<TypeVariableType *, 4> referencedTypeVars;
Expand Down Expand Up @@ -1138,14 +1171,17 @@ void PotentialBindings::addLiteral(Constraint *constraint) {
Literals.insert(constraint);
}

bool BindingSet::isViable(PotentialBinding &binding) {
bool BindingSet::isViable(PotentialBinding &binding, bool isTransitive) {
// Prevent against checking against the same opened nominal type
// over and over again. Doing so means redundant work in the best
// case. In the worst case, we'll produce lots of duplicate solutions
// for this constraint system, which is problematic for overload
// resolution.
auto type = binding.BindingType;

if (isTransitive && !checkTypeOfBinding(TypeVar, type))
return false;

auto *NTD = type->getAnyNominal();
if (!NTD)
return true;
Expand Down
110 changes: 110 additions & 0 deletions test/Sema/keypath_bidirectional_inference.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// RUN: %target-typecheck-verify-swift -target %target-cpu-apple-macosx10.15 -solver-expression-time-threshold=1 -swift-version 5

// REQUIRES: OS=macosx

import Combine

enum Status {
case up
case down
}

protocol StatusMonitor {
var statusPublisher: AnyPublisher<Status, Never> { get }
}

protocol UIController {}
protocol ControllerProtocol {}

class TestViewController : UIController, ControllerProtocol {
}

class OtherController {
var innerController: (any UIController & ControllerProtocol)? = nil
}

class Test1 {
var monitor: StatusMonitor

var subscriptions: [AnyCancellable] = []
var status: Status? = nil
var statuses: [Status]? = nil

init(monitor: StatusMonitor) {
self.monitor = monitor
}

func simpleMapTest() {
monitor.statusPublisher
.map { $0 }
.assign(to: \.status, on: self) // Ok
.store(in: &subscriptions)
}

func transformationTest() {
monitor.statusPublisher
.map { _ in (0...1).map { _ in .up } }
.assign(to: \.statuses, on: self) // Ok
.store(in: &subscriptions)
}
}

class FilteringTest {
@Published var flag = false

func test(viewController: inout OtherController) {
_ = $flag.filter { !$0 }
.map { _ in TestViewController() }
.first()
.handleEvents(receiveOutput: { _ in
print("event")
})
.assign(to: \.innerController, on: viewController) // Ok
}
}

extension Sequence {
func sorted<T: Comparable>(by keyPath: KeyPath<Element, T>) -> [Element] {
[]
}
}

func testCollectionUpcastWithTupleLabelErasure() {
struct Item {}

enum Info : Int, Hashable {
case one = 1
}


func test(data: [Info: [Item]]) -> [(Info, [Item])] {
data.map { $0 }
.sorted(by: \.key.rawValue) // Ok
}
}

do {
struct URL {
var path: String
func appendingPathComponent(_: String) -> URL { fatalError() }
}

struct EntryPoint {
var directory: URL { fatalError() }
}

func test(entryPoint: EntryPoint, data: [[String]]) {
let _ = data.map { suffixes in
let elements = ["a", "b"]
.flatMap { dir in
let directory = entryPoint.directory.appendingPathComponent(dir)
return suffixes.map { suffix in
directory.appendingPathComponent("\(suffix)")
}
}
.map(\.path) // Ok

return elements.joined(separator: ",")
}
}
}