Skip to content

Commit a44f501

Browse files
committed
Allow function-function conversions for keypath literals
Remove keypath subtype asserts; always use cached root type Add tests for keypaths converted to funcs with inout param Add unit test for overload selection
1 parent f24976b commit a44f501

File tree

9 files changed

+333
-72
lines changed

9 files changed

+333
-72
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,12 @@ ERROR(expr_smart_keypath_application_type_mismatch,none,
698698
"key path of type %0 cannot be applied to a base of type %1",
699699
(Type, Type))
700700
ERROR(expr_keypath_root_type_mismatch, none,
701+
"key path root type %0 cannot be converted to contextual type %1",
702+
(Type, Type))
703+
ERROR(expr_keypath_type_mismatch, none,
704+
"key path of type %0 cannot be converted to contextual type %1",
705+
(Type, Type))
706+
ERROR(expr_keypath_application_root_type_mismatch, none,
701707
"key path with root type %0 cannot be applied to a base of type %1",
702708
(Type, Type))
703709
ERROR(expr_swift_keypath_anyobject_root,none,

include/swift/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,11 @@ class Solution {
16981698
/// Retrieve the type of the \p ComponentIndex-th component in \p KP.
16991699
Type getType(const KeyPathExpr *KP, unsigned ComponentIndex) const;
17001700

1701+
TypeVariableType *getKeyPathRootType(const KeyPathExpr *keyPath) const;
1702+
1703+
TypeVariableType *
1704+
getKeyPathRootTypeIfAvailable(const KeyPathExpr *keyPath) const;
1705+
17011706
/// Retrieve the type of the given node as recorded in this solution
17021707
/// and resolve all of the type variables in contains to form a fully
17031708
/// "resolved" concrete type.

lib/Sema/CSApply.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5013,20 +5013,18 @@ namespace {
50135013
// Resolve each of the components.
50145014
bool didOptionalChain = false;
50155015
bool isFunctionType = false;
5016-
Type baseTy, leafTy;
5016+
auto baseTy = cs.simplifyType(solution.getKeyPathRootType(E));
5017+
Type leafTy;
50175018
Type exprType = cs.getType(E);
50185019
if (auto fnTy = exprType->getAs<FunctionType>()) {
5019-
baseTy = fnTy->getParams()[0].getParameterType();
50205020
leafTy = fnTy->getResult();
50215021
isFunctionType = true;
50225022
} else if (auto *existential = exprType->getAs<ExistentialType>()) {
50235023
auto layout = existential->getExistentialLayout();
50245024
auto keyPathTy = layout.explicitSuperclass->castTo<BoundGenericType>();
5025-
baseTy = keyPathTy->getGenericArgs()[0];
50265025
leafTy = keyPathTy->getGenericArgs()[1];
50275026
} else {
50285027
auto keyPathTy = exprType->castTo<BoundGenericType>();
5029-
baseTy = keyPathTy->getGenericArgs()[0];
50305028
leafTy = keyPathTy->getGenericArgs()[1];
50315029
}
50325030

@@ -5145,13 +5143,11 @@ namespace {
51455143
assert(!resolvedComponents.empty());
51465144
componentTy = resolvedComponents.back().getComponentType();
51475145
}
5148-
5146+
51495147
// Wrap a non-optional result if there was chaining involved.
51505148
if (didOptionalChain && componentTy &&
51515149
!componentTy->hasUnresolvedType() &&
51525150
!componentTy->getWithoutSpecifierType()->isEqual(leafTy)) {
5153-
assert(leafTy->getOptionalObjectType()->isEqual(
5154-
componentTy->getWithoutSpecifierType()));
51555151
auto component = KeyPathExpr::Component::forOptionalWrap(leafTy);
51565152
resolvedComponents.push_back(component);
51575153
componentTy = leafTy;
@@ -5164,11 +5160,6 @@ namespace {
51645160
// See whether there's an equivalent ObjC key path string we can produce
51655161
// for interop purposes.
51665162
checkAndSetObjCKeyPathString(E);
5167-
5168-
// The final component type ought to line up with the leaf type of the
5169-
// key path.
5170-
assert(!componentTy || componentTy->hasUnresolvedType()
5171-
|| componentTy->getWithoutSpecifierType()->isEqual(leafTy));
51725163

51735164
if (!isFunctionType)
51745165
return E;
@@ -9844,6 +9835,21 @@ Type Solution::getType(const KeyPathExpr *KP, unsigned I) const {
98449835
return keyPathComponentTypes.find(std::make_pair(KP, I))->second;
98459836
}
98469837

9838+
TypeVariableType *
9839+
Solution::getKeyPathRootType(const KeyPathExpr *keyPath) const {
9840+
auto result = getKeyPathRootTypeIfAvailable(keyPath);
9841+
assert(result);
9842+
return result;
9843+
}
9844+
9845+
TypeVariableType *
9846+
Solution::getKeyPathRootTypeIfAvailable(const KeyPathExpr *keyPath) const {
9847+
auto result = KeyPaths.find(keyPath);
9848+
if (result != KeyPaths.end())
9849+
return std::get<0>(result->second);
9850+
return nullptr;
9851+
}
9852+
98479853
Type Solution::getResolvedType(ASTNode node) const {
98489854
return simplifyType(getType(node));
98499855
}

lib/Sema/CSDiagnostics.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,9 +2496,13 @@ bool ContextualFailure::diagnoseAsError() {
24962496

24972497
if (path.empty()) {
24982498
if (auto *KPE = getAsExpr<KeyPathExpr>(anchor)) {
2499-
emitDiagnosticAt(KPE->getLoc(),
2500-
diag::expr_keypath_type_covert_to_contextual_type,
2501-
getFromType(), getToType());
2499+
Diag<Type, Type> diag;
2500+
if (auto ctxDiag = getDiagnosticFor(CTP, getToType())) {
2501+
diag = *ctxDiag;
2502+
} else {
2503+
diag = diag::expr_keypath_type_mismatch;
2504+
}
2505+
emitDiagnosticAt(KPE->getLoc(), diag, getFromType(), getToType());
25022506
return true;
25032507
}
25042508

@@ -2749,9 +2753,14 @@ bool ContextualFailure::diagnoseAsError() {
27492753
break;
27502754
}
27512755

2756+
case ConstraintLocator::FunctionResult:
27522757
case ConstraintLocator::KeyPathValue: {
2753-
diagnostic = diag::expr_keypath_value_covert_to_contextual_type;
2754-
break;
2758+
if (auto *KPE = getAsExpr<KeyPathExpr>(anchor)) {
2759+
diagnostic = diag::expr_keypath_value_covert_to_contextual_type;
2760+
break;
2761+
} else {
2762+
return false;
2763+
}
27552764
}
27562765

27572766
default:
@@ -8249,13 +8258,24 @@ bool CoercionAsForceCastFailure::diagnoseAsError() {
82498258

82508259
bool KeyPathRootTypeMismatchFailure::diagnoseAsError() {
82518260
auto locator = getLocator();
8261+
auto anchor = locator->getAnchor();
82528262
assert(locator->isKeyPathRoot() && "Expected a key path root");
8253-
8254-
auto baseType = getFromType();
8255-
auto rootType = getToType();
82568263

8257-
emitDiagnostic(diag::expr_keypath_root_type_mismatch,
8258-
rootType, baseType);
8264+
8265+
8266+
if (isExpr<KeyPathApplicationExpr>(anchor) || isExpr<SubscriptExpr>(anchor)) {
8267+
auto baseType = getFromType();
8268+
auto rootType = getToType();
8269+
8270+
emitDiagnostic(diag::expr_keypath_application_root_type_mismatch,
8271+
rootType, baseType);
8272+
} else {
8273+
auto rootType = getFromType();
8274+
auto expectedType = getToType();
8275+
8276+
emitDiagnostic(diag::expr_keypath_root_type_mismatch, rootType,
8277+
expectedType);
8278+
}
82598279
return true;
82608280
}
82618281

lib/Sema/CSSimplify.cpp

Lines changed: 71 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5209,6 +5209,13 @@ bool ConstraintSystem::repairFailures(
52095209
});
52105210
};
52115211

5212+
auto hasAnyRestriction = [&]() {
5213+
return llvm::any_of(conversionsOrFixes,
5214+
[](const RestrictionOrFix &correction) {
5215+
return bool(correction.getRestriction());
5216+
});
5217+
};
5218+
52125219
// Check whether this is a tuple with a single unlabeled element
52135220
// i.e. `(_: Int)` and return type of that element if so. Note that
52145221
// if the element is pack expansion type the tuple is significant.
@@ -5234,6 +5241,40 @@ bool ConstraintSystem::repairFailures(
52345241
return true;
52355242
}
52365243

5244+
auto maybeRepairKeyPathResultFailure = [&](KeyPathExpr *kpExpr) {
5245+
if (lhs->isPlaceholder() || rhs->isPlaceholder())
5246+
return true;
5247+
if (lhs->isTypeVariableOrMember() || rhs->isTypeVariableOrMember())
5248+
return false;
5249+
5250+
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality) ||
5251+
hasConversionOrRestriction(ConversionRestrictionKind::ValueToOptional))
5252+
return false;
5253+
5254+
auto i = kpExpr->getComponents().size() - 1;
5255+
auto lastCompLoc =
5256+
getConstraintLocator(kpExpr, LocatorPathElt::KeyPathComponent(i));
5257+
if (hasFixFor(lastCompLoc, FixKind::AllowTypeOrInstanceMember))
5258+
return true;
5259+
5260+
auto *keyPathLoc = getConstraintLocator(anchor);
5261+
5262+
if (hasFixFor(keyPathLoc))
5263+
return true;
5264+
5265+
if (auto contextualInfo = getContextualTypeInfo(anchor)) {
5266+
if (hasFixFor(getConstraintLocator(
5267+
keyPathLoc,
5268+
LocatorPathElt::ContextualType(contextualInfo->purpose))))
5269+
return true;
5270+
}
5271+
5272+
conversionsOrFixes.push_back(IgnoreContextualType::create(
5273+
*this, lhs, rhs,
5274+
getConstraintLocator(keyPathLoc, ConstraintLocator::KeyPathValue)));
5275+
return true;
5276+
};
5277+
52375278
if (path.empty()) {
52385279
if (!anchor)
52395280
return false;
@@ -5253,9 +5294,9 @@ bool ConstraintSystem::repairFailures(
52535294
// instance fix recorded.
52545295
if (auto *kpExpr = getAsExpr<KeyPathExpr>(anchor)) {
52555296
if (isKnownKeyPathType(lhs) && isKnownKeyPathType(rhs)) {
5256-
// If we have keypath capabilities for both sides and one of the bases
5257-
// is unresolved, it is too early to record fix.
5258-
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality))
5297+
// If we have a conversion happening here, we should let fix happen in
5298+
// simplifyRestrictedConstraint.
5299+
if (hasAnyRestriction())
52595300
return false;
52605301
}
52615302

@@ -5655,10 +5696,7 @@ bool ConstraintSystem::repairFailures(
56555696

56565697
// If there are any restrictions here we need to wait and let
56575698
// `simplifyRestrictedConstraintImpl` handle them.
5658-
if (llvm::any_of(conversionsOrFixes,
5659-
[](const RestrictionOrFix &correction) {
5660-
return bool(correction.getRestriction());
5661-
}))
5699+
if (hasAnyRestriction())
56625700
break;
56635701

56645702
if (auto *fix = fixPropertyWrapperFailure(
@@ -6077,10 +6115,7 @@ bool ConstraintSystem::repairFailures(
60776115

60786116
// If there are any restrictions here we need to wait and let
60796117
// `simplifyRestrictedConstraintImpl` handle them.
6080-
if (llvm::any_of(conversionsOrFixes,
6081-
[](const RestrictionOrFix &correction) {
6082-
return bool(correction.getRestriction());
6083-
}))
6118+
if (hasAnyRestriction())
60846119
break;
60856120

60866121
// `lhs` - is an result type and `rhs` is a contextual type.
@@ -6099,6 +6134,10 @@ bool ConstraintSystem::repairFailures(
60996134
return true;
61006135
}
61016136

6137+
if (auto *kpExpr = getAsExpr<KeyPathExpr>(anchor)) {
6138+
return maybeRepairKeyPathResultFailure(kpExpr);
6139+
}
6140+
61026141
auto *loc = getConstraintLocator(anchor, {path.begin(), path.end() - 1});
61036142
// If this is a mismatch between contextual type and (trailing)
61046143
// closure with explicitly specified result type let's record it
@@ -6670,37 +6709,9 @@ bool ConstraintSystem::repairFailures(
66706709
return true;
66716710
}
66726711
case ConstraintLocator::KeyPathValue: {
6673-
if (lhs->isPlaceholder() || rhs->isPlaceholder())
6674-
return true;
6675-
if (lhs->isTypeVariableOrMember() || rhs->isTypeVariableOrMember())
6676-
break;
6677-
6678-
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality) ||
6679-
hasConversionOrRestriction(ConversionRestrictionKind::ValueToOptional))
6680-
return false;
6681-
6682-
auto kpExpr = castToExpr<KeyPathExpr>(anchor);
6683-
auto i = kpExpr->getComponents().size() - 1;
6684-
auto lastCompLoc =
6685-
getConstraintLocator(kpExpr, LocatorPathElt::KeyPathComponent(i));
6686-
if (hasFixFor(lastCompLoc, FixKind::AllowTypeOrInstanceMember))
6712+
if (maybeRepairKeyPathResultFailure(getAsExpr<KeyPathExpr>(anchor)))
66876713
return true;
66886714

6689-
auto *keyPathLoc = getConstraintLocator(anchor);
6690-
6691-
if (hasFixFor(keyPathLoc))
6692-
return true;
6693-
6694-
if (auto contextualInfo = getContextualTypeInfo(anchor)) {
6695-
if (hasFixFor(getConstraintLocator(
6696-
keyPathLoc,
6697-
LocatorPathElt::ContextualType(contextualInfo->purpose))))
6698-
return true;
6699-
}
6700-
6701-
conversionsOrFixes.push_back(IgnoreContextualType::create(
6702-
*this, lhs, rhs,
6703-
getConstraintLocator(keyPathLoc, ConstraintLocator::KeyPathValue)));
67046715
break;
67056716
}
67066717
default:
@@ -12221,12 +12232,26 @@ ConstraintSystem::simplifyKeyPathConstraint(
1222112232

1222212233
if (auto fnTy = contextualTy->getAs<FunctionType>()) {
1222312234
assert(fnTy->getParams().size() == 1);
12224-
// Match up the root and value types to the function's param and return
12225-
// types. Note that we're using the type of the parameter as referenced
12226-
// from inside the function body as we'll be transforming the code into:
12227-
// { root in root[keyPath: kp] }.
12228-
contextualRootTy = fnTy->getParams()[0].getParameterType();
12229-
contextualValueTy = fnTy->getResult();
12235+
// Key paths may be converted to a function of compatible type. We will
12236+
// later form from this key path an implicit closure of the form
12237+
// `{ root in root[keyPath: kp] }` so any conversions that are valid with
12238+
// a source type of `(Root) -> Value` should be valid here too.
12239+
auto rootParam = AnyFunctionType::Param(rootTy);
12240+
auto kpFnTy = FunctionType::get(rootParam, valueTy, fnTy->getExtInfo());
12241+
12242+
// Note: because the keypath is applied to `root` as a parameter internal
12243+
// to the closure, we use the function parameter's "parameter type" rather
12244+
// than the raw type. This enables things like:
12245+
// ```
12246+
// let countKeyPath: (String...) -> Int = \.count
12247+
// ```
12248+
auto paramTy = fnTy->getParams()[0].getParameterType();
12249+
auto paramParam = AnyFunctionType::Param(paramTy);
12250+
auto paramFnTy = FunctionType::get(paramParam, fnTy->getResult(),
12251+
fnTy->getExtInfo());
12252+
12253+
return matchTypes(kpFnTy, paramFnTy, ConstraintKind::Conversion, subflags,
12254+
locator).isSuccess();
1223012255
}
1223112256

1223212257
assert(contextualRootTy && contextualValueTy);

test/Constraints/keypath.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ func testVariadicKeypathAsFunc() {
8181

8282
// These are not okay, the KeyPath should have a base that matches the
8383
// internal parameter type of the function, i.e (S...).
84-
let _: (S...) -> Int = \S.i // expected-error {{key path with root type 'S...' cannot be applied to a base of type 'S'}}
85-
takesVariadicFnWithGenericRet(\S.i) // expected-error {{key path with root type 'S...' cannot be applied to a base of type 'S'}}
84+
let _: (S...) -> Int = \S.i // expected-error {{key path root type 'S' cannot be converted to contextual type 'S...'}}
85+
takesVariadicFnWithGenericRet(\S.i) // expected-error {{key path root type 'S' cannot be converted to contextual type 'S...'}}
8686
}
8787

8888
// rdar://problem/54322807
@@ -231,7 +231,7 @@ func issue_65965() {
231231
let refKP: ReferenceWritableKeyPath<S, String>
232232
refKP = \.s
233233
// expected-error@-1 {{cannot convert key path type 'WritableKeyPath<S, String>' to contextual type 'ReferenceWritableKeyPath<S, String>'}}
234-
234+
235235
let writeKP: WritableKeyPath<S, String>
236236
writeKP = \.v
237237
// expected-error@-1 {{cannot convert key path type 'KeyPath<S, String>' to contextual type 'WritableKeyPath<S, String>'}}

0 commit comments

Comments
 (0)