Skip to content

[SE-0326] Re-enable multi-statement closure inference by default #41730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ namespace swift {

/// Enable experimental support for type inference through multi-statement
/// closures.
bool EnableMultiStatementClosureInference = false;
bool EnableMultiStatementClosureInference = true;

/// Enable experimental support for generic parameter inference in
/// parameter positions from associated default expressions.
Expand Down
13 changes: 13 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,19 @@ class LocatorPathElt::ClosureBodyElement final
}
};

class LocatorPathElt::PatternBindingElement final
: public StoredIntegerElement<1> {
public:
PatternBindingElement(unsigned index)
: StoredIntegerElement(ConstraintLocator::PatternBindingElement, index) {}

unsigned getIndex() const { return getValue(); }

static bool classof(const LocatorPathElt *elt) {
return elt->getKind() == ConstraintLocator::PatternBindingElement;
}
};

namespace details {
template <typename CustomPathElement>
class PathElement {
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Sema/ConstraintLocatorPathElts.def
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ SIMPLE_LOCATOR_PATH_ELT(ImplicitDynamicMemberSubscript)
/// The element of the closure body e.g. statement, declaration, or expression.
CUSTOM_LOCATOR_PATH_ELT(ClosureBodyElement)

/// The element of the pattern binding declaration.
CUSTOM_LOCATOR_PATH_ELT(PatternBindingElement)

#undef LOCATOR_PATH_ELT
#undef CUSTOM_LOCATOR_PATH_ELT
#undef SIMPLE_LOCATOR_PATH_ELT
Expand Down
93 changes: 90 additions & 3 deletions lib/Sema/CSClosure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,12 +483,81 @@ class ClosureConstraintGenerator
});
}

void visitPatternBinding(PatternBindingDecl *patternBinding,
SmallVectorImpl<ElementInfo> &patterns) {
auto *baseLoc = cs.getConstraintLocator(
locator, LocatorPathElt::ClosureBodyElement(patternBinding));

for (unsigned index : range(patternBinding->getNumPatternEntries())) {
auto *pattern = TypeChecker::resolvePattern(
patternBinding->getPattern(index), patternBinding->getDeclContext(),
/*isStmtCondition=*/true);

if (!pattern) {
hadError = true;
return;
}

// Reset binding to point to the resolved pattern. This is required
// before calling `forPatternBindingDecl`.
patternBinding->setPattern(index, pattern,
patternBinding->getInitContext(index));

patterns.push_back(makeElement(
patternBinding,
cs.getConstraintLocator(
baseLoc, LocatorPathElt::PatternBindingElement(index))));
}
}

void visitPatternBindingElement(PatternBindingDecl *patternBinding) {
assert(locator->isLastElement<LocatorPathElt::PatternBindingElement>());

auto index =
locator->castLastElementTo<LocatorPathElt::PatternBindingElement>()
.getIndex();

auto contextualPattern =
ContextualPattern::forPatternBindingDecl(patternBinding, index);
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);

// Fail early if pattern couldn't be type-checked.
if (!patternType || patternType->hasError()) {
hadError = true;
return;
}

auto *pattern = patternBinding->getPattern(index);
auto *init = patternBinding->getInit(index);

if (!init && patternBinding->isDefaultInitializable(index) &&
pattern->hasStorage()) {
init = TypeChecker::buildDefaultInitializer(patternType);
}

auto target = init ? SolutionApplicationTarget::forInitialization(
init, patternBinding->getDeclContext(),
patternType, patternBinding, index,
/*bindPatternVarsOneWay=*/false)
: SolutionApplicationTarget::forUninitializedVar(
patternBinding, index, patternType);

if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
hadError = true;
return;
}

// Keep track of this binding entry.
cs.setSolutionApplicationTarget({patternBinding, index}, target);
}

void visitDecl(Decl *decl) {
if (isSupportedMultiStatementClosure()) {
if (auto patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
SolutionApplicationTarget target(patternBinding);
if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow))
hadError = true;
if (locator->isLastElement<LocatorPathElt::PatternBindingElement>())
visitPatternBindingElement(patternBinding);
else
llvm_unreachable("cannot visit pattern binding directly");
return;
}
}
Expand Down Expand Up @@ -788,6 +857,13 @@ class ClosureConstraintGenerator
element.is<Expr *>() &&
(!ctx.LangOpts.Playground && !ctx.LangOpts.DebuggerSupport);

if (auto *decl = element.dyn_cast<Decl *>()) {
if (auto *PDB = dyn_cast<PatternBindingDecl>(decl)) {
visitPatternBinding(PDB, elements);
continue;
}
}

elements.push_back(makeElement(
element,
cs.getConstraintLocator(
Expand Down Expand Up @@ -1600,6 +1676,17 @@ void ConjunctionElement::findReferencedVariables(

TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars);

if (auto *patternBinding =
dyn_cast_or_null<PatternBindingDecl>(element.dyn_cast<Decl *>())) {
if (auto patternBindingElt =
locator
->getLastElementAs<LocatorPathElt::PatternBindingElement>()) {
if (auto *init = patternBinding->getInit(patternBindingElt->getIndex()))
init->walk(refFinder);
return;
}
}

if (element.is<Decl *>() || element.is<StmtConditionElement *>() ||
element.is<Expr *>() || element.isStmt(StmtKind::Return))
element.walk(refFinder);
Expand Down
66 changes: 64 additions & 2 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2299,9 +2299,71 @@ namespace {
locator);
}

// If we have a type to ascribe to the variable, do so now.
if (oneWayVarType)
// Ascribe a type to the declaration so it's always available to
// constraint system.
if (oneWayVarType) {
CS.setType(var, oneWayVarType);
} else if (externalPatternType) {
// If there is an externally imposed type, that's what the
// declaration is going to be bound to.
CS.setType(var, externalPatternType);
} else {
// Otherwise, let's use the type of the pattern. The type
// of the declaration has to be r-value, so let's add an
// equality constraint if pattern type has any type variables
// that are allowed to be l-value.
bool foundLValueVars = false;

// Note that it wouldn't be always correct to allocate a single type
// variable, that disallows l-value types, to use as a declaration
// type because equality constraint would drop TVO_CanBindToLValue
// from the right-hand side (which is not the case for `OneWayEqual`)
// e.g.:
//
// sturct S { var x, y: Int }
//
// func test(s: S) {
// let (x, y) = (s.x, s.y)
// }
//
// Single type variable approach results in the following constraint:
// `$T_x_y = ($T_s_x, $T_s_y)` where both `$T_s_x` and `$T_s_y` have
// to allow l-value, but `$T_x_y` does not. Early simplication of `=`
// constraint (due to right-hand side being a "concrete" tuple type)
// would drop l-value option from `$T_s_x` and `$T_s_y` which leads to
// a failure during member lookup because `x` and `y` are both
// `@lvalue Int`. To avoid that, declaration type would mimic pattern
// type with all l-value options stripped, so the equality constraint
// becomes `($T_x, $_T_y) = ($T_s_x, $T_s_y)` which doesn't result in
// stripping of l-value flag from the right-hand side since
// simplification can only happen when either side is resolved.
auto declTy = varType.transform([&](Type type) -> Type {
if (auto *typeVar = type->getAs<TypeVariableType>()) {
if (typeVar->getImpl().canBindToLValue()) {
foundLValueVars = true;

// Drop l-value from the options but preserve the rest.
auto options = typeVar->getImpl().getRawOptions();
options &= ~TVO_CanBindToLValue;

return CS.createTypeVariable(typeVar->getImpl().getLocator(),
options);
}
}
return type;
});

// If pattern types allows l-value types, let's create an
// equality constraint between r-value only declaration type
// and l-value pattern type that would take care of looking
// through l-values when necessary.
if (foundLValueVars) {
CS.addConstraint(ConstraintKind::Equal, declTy, varType,
CS.getConstraintLocator(locator));
}

CS.setType(var, declTy);
}

return setType(varType);
}
Expand Down
17 changes: 15 additions & 2 deletions lib/Sema/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,21 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const {
}

if (Kind == ConstraintKind::ClosureBodyElement) {
Out << "closure body element ";
getClosureElement().dump(Out);
auto *locator = getLocator();
auto element = getClosureElement();

if (auto patternBindingElt =
locator
->getLastElementAs<LocatorPathElt::PatternBindingElement>()) {
auto *patternBinding = cast<PatternBindingDecl>(element.get<Decl *>());
Out << "pattern binding element @ ";
Out << patternBindingElt->getIndex() << " : ";
patternBinding->getPattern(patternBindingElt->getIndex())->dump(Out);
} else {
Out << "closure body element ";
getClosureElement().dump(Out);
}

return;
}

Expand Down
9 changes: 9 additions & 0 deletions lib/Sema/ConstraintLocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ unsigned LocatorPathElt::getNewSummaryFlags() const {
case ConstraintLocator::ClosureBodyElement:
case ConstraintLocator::PackType:
case ConstraintLocator::PackElement:
case ConstraintLocator::PatternBindingElement:
return 0;

case ConstraintLocator::FunctionArgument:
Expand Down Expand Up @@ -578,6 +579,14 @@ void ConstraintLocator::dump(SourceManager *sm, raw_ostream &out) const {
out << "pack element #" << llvm::utostr(packElt.getIndex());
break;
}

case PatternBindingElement: {
auto patternBindingElt =
elt.castTo<LocatorPathElt::PatternBindingElement>();
out << "pattern binding element #"
<< llvm::utostr(patternBindingElt.getIndex());
break;
}
}
}
out << ']';
Expand Down
17 changes: 8 additions & 9 deletions test/Constraints/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ struct CC {}
func callCC<U>(_ f: (CC) -> U) -> () {}

func typeCheckMultiStmtClosureCrash() {
callCC { // expected-error {{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{none}}
callCC {
_ = $0
return 1
}
Expand Down Expand Up @@ -312,24 +312,23 @@ func testAcceptNothingToInt(ac1: @autoclosure () -> Int) {
struct Thing {
init?() {}
}
// This throws a compiler error
let things = Thing().map { thing in // expected-error {{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{34-34=-> <#Result#> }}
// Commenting out this makes it compile

let things = Thing().map { thing in
_ = thing
return thing
}


// <rdar://problem/21675896> QoI: [Closure return type inference] Swift cannot find members for the result of inlined lambdas with branches
func r21675896(file : String) {
let x: String = { // expected-error {{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{20-20= () -> <#Result#> in }}
let x: String = {
if true {
return "foo"
}
else {
return file
}
}().pathExtension
}().pathExtension // expected-error {{value of type 'String' has no member 'pathExtension'}}
}


Expand Down Expand Up @@ -360,7 +359,7 @@ func someGeneric19997471<T>(_ x: T) {


// <rdar://problem/20921068> Swift fails to compile: [0].map() { _ in let r = (1,2).0; return r }
[0].map { // expected-error {{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{5-5=-> <#Result#> }}
let _ = [0].map {
_ in
let r = (1,2).0
return r
Expand Down Expand Up @@ -408,7 +407,7 @@ func r20789423() {
print(p.f(p)()) // expected-error {{cannot convert value of type 'C' to expected argument type 'Int'}}
// expected-error@-1:11 {{cannot call value of non-function type '()'}}

let _f = { (v: Int) in // expected-error {{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{23-23=-> <#Result#> }}
let _f = { (v: Int) in
print("a")
return "hi"
}
Expand Down Expand Up @@ -1127,7 +1126,7 @@ func rdar76058892() {
func experiment(arr: [S]?) {
test { // expected-error {{contextual closure type '() -> String' expects 0 arguments, but 1 was used in closure body}}
if let arr = arr {
arr.map($0.test) // expected-note {{anonymous closure parameter '$0' is used here}}
arr.map($0.test) // expected-note {{anonymous closure parameter '$0' is used here}} // expected-error {{generic parameter 'T' could not be inferred}}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions test/Constraints/diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func ***~(_: Int, _: String) { }
i ***~ i // expected-error{{cannot convert value of type 'Int' to expected argument type 'String'}}

@available(*, unavailable, message: "call the 'map()' method on the sequence")
public func myMap<C : Collection, T>(
public func myMap<C : Collection, T>( // expected-note {{'myMap' has been explicitly marked unavailable here}}
_ source: C, _ transform: (C.Iterator.Element) -> T
) -> [T] {
fatalError("unavailable function can't be called")
Expand All @@ -161,7 +161,7 @@ public func myMap<T, U>(_ x: T?, _ f: (T) -> U) -> U? {

// <rdar://problem/20142523>
func rdar20142523() {
myMap(0..<10, { x in // expected-error{{cannot infer return type for closure with multiple statements; add explicit type to disambiguate}} {{21-21=-> <#Result#> }} {{educational-notes=complex-closure-inference}}
_ = myMap(0..<10, { x in // expected-error {{'myMap' is unavailable: call the 'map()' method on the sequence}}
()
return x
})
Expand Down
14 changes: 11 additions & 3 deletions test/Constraints/members.swift
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,10 @@ func rdar50679161() {

func foo() {
_ = { () -> Void in
// Missing `.self` or `init` is not diagnosed here because there are errors in
// `if let` statement and `MiscDiagnostics` only run if the body is completely valid.
var foo = S
// expected-error@-1 {{expected member name or constructor call after type name}}
// expected-note@-2 {{add arguments after the type to construct a value of the type}}
// expected-note@-3 {{use '.self' to reference the type object}}

if let v = Int?(1) {
var _ = Q(
a: v + foo.w,
Expand All @@ -610,6 +610,14 @@ func rdar50679161() {
)
}
}

_ = { () -> Void in
var foo = S
// expected-error@-1 {{expected member name or constructor call after type name}}
// expected-note@-2 {{add arguments after the type to construct a value of the type}}
// expected-note@-3 {{use '.self' to reference the type object}}
print(foo)
}
}
}

Expand Down
Loading