Skip to content

Commit 59154d6

Browse files
committed
[ConstraintSystem] Support solving expression patterns via injecting call to ~= operator
Augment the constraint solver to fallback to implicit `~=` application when member couldn't be found for `EnumElement` patterns because `case` statement should be able to match enum member directly, as well as through an implicit `~=` operator application.
1 parent 16f6a2e commit 59154d6

File tree

5 files changed

+183
-4
lines changed

5 files changed

+183
-4
lines changed

lib/Sema/CSApply.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9007,6 +9007,31 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
90079007
Type patternType = solution.simplifyType(solution.getType(info.pattern));
90089008
patternType = patternType->reconstituteSugar(/*recursive=*/false);
90099009

9010+
// Check whether this enum element is resolved via ~= application.
9011+
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9012+
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
9013+
auto *EP = target->getExprPattern();
9014+
auto enumType = solution.simplifyType(EP->getType());
9015+
9016+
auto *matchCall = target->getAsExpr();
9017+
9018+
auto *result = matchCall->walk(*this);
9019+
if (!result)
9020+
return None;
9021+
9022+
{
9023+
auto *matchVar = EP->getMatchVar();
9024+
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9025+
}
9026+
9027+
EP->setMatchExpr(result);
9028+
EP->setType(enumType);
9029+
9030+
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9031+
return target;
9032+
}
9033+
}
9034+
90109035
// Coerce the pattern to its appropriate type.
90119036
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
90129037
patternOptions |= TypeResolutionFlags::OverrideType;

lib/Sema/CSSimplify.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8629,6 +8629,69 @@ fixMemberRef(ConstraintSystem &cs, Type baseTy,
86298629
return nullptr;
86308630
}
86318631

8632+
/// Convert the given enum element pattern into an expression pattern
8633+
/// and synthesize ~= operator application to find the type of the
8634+
/// element.
8635+
static bool inferEnumMemberThroughTildeEqualsOperator(
8636+
ConstraintSystem &cs, EnumElementPattern *pattern, Type enumTy,
8637+
Type elementTy, ConstraintLocator *locator) {
8638+
if (!pattern->hasUnresolvedOriginalExpr())
8639+
return true;
8640+
8641+
auto &DC = cs.DC;
8642+
auto &ctx = cs.getASTContext();
8643+
8644+
// Slots for expression and variable are going to be filled via
8645+
// synthesizing ~= operator application.
8646+
auto *EP = new (ctx) ExprPattern(pattern->getUnresolvedOriginalExpr(),
8647+
/*matchExpr=*/nullptr, /*matchVar=*/nullptr);
8648+
8649+
auto tildeEqualsApplication =
8650+
TypeChecker::synthesizeTildeEqualsOperatorApplication(EP, DC, enumTy);
8651+
8652+
if (!tildeEqualsApplication)
8653+
return true;
8654+
8655+
VarDecl *matchVar;
8656+
Expr *matchCall;
8657+
8658+
std::tie(matchVar, matchCall) = *tildeEqualsApplication;
8659+
8660+
// result of ~= operator is always a `Bool`.
8661+
auto target = SolutionApplicationTarget::forExprPattern(
8662+
matchCall, DC, EP, ctx.getBoolDecl()->getDeclaredInterfaceType());
8663+
8664+
DiagnosticTransaction diagnostics(ctx.Diags);
8665+
{
8666+
if (cs.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
8667+
/*leaveClosureBodyUnchecked=*/false)) {
8668+
// Skip diagnostics if they are disabled, otherwise it would result in
8669+
// duplicate diagnostics, since this operation is going to be repeated
8670+
// in diagnostic mode.
8671+
if (!cs.shouldAttemptFixes())
8672+
diagnostics.abort();
8673+
8674+
return true;
8675+
}
8676+
}
8677+
8678+
cs.generateConstraints(target, FreeTypeVariableBinding::Disallow);
8679+
8680+
// Sub-expression associated with expression pattern is the enum element
8681+
// access which needs to be connected to the provided element type.
8682+
cs.addConstraint(ConstraintKind::Conversion, cs.getType(EP->getSubExpr()),
8683+
elementTy, cs.getConstraintLocator(EP));
8684+
8685+
// Store the $match variable and binary expression for solution application.
8686+
EP->setMatchVar(matchVar);
8687+
EP->setMatchExpr(matchCall);
8688+
EP->setType(enumTy);
8689+
8690+
cs.setSolutionApplicationTarget(pattern, target);
8691+
8692+
return false;
8693+
}
8694+
86328695
ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
86338696
ConstraintKind kind, Type baseTy, DeclNameRef member, Type memberTy,
86348697
DeclContext *useDC, FunctionRefKind functionRefKind,
@@ -8834,6 +8897,30 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
88348897
(void)recordFix(fix);
88358898
}
88368899

8900+
// If there were no results from a direct enum lookup, let's attempt
8901+
// to resolve this member via ~= operator application.
8902+
if (candidates.empty()) {
8903+
if (auto patternLoc =
8904+
locator->getLastElementAs<LocatorPathElt::PatternMatch>()) {
8905+
if (auto *enumElement =
8906+
dyn_cast<EnumElementPattern>(patternLoc->getPattern())) {
8907+
auto enumType = baseObjTy->getMetatypeInstanceType();
8908+
8909+
// If the synthesis of ~= resulted in errors (i.e. broken stdlib)
8910+
// that would be diagnosed inline, so let's just fall through and
8911+
// let this situation be diagnosed as a missing member.
8912+
auto hadErrors = inferEnumMemberThroughTildeEqualsOperator(
8913+
*this, enumElement, enumType, memberTy, locator);
8914+
8915+
// Let's consider current member constraint solved because it's
8916+
// replaced by a new set of constraints that would resolve member
8917+
// type.
8918+
if (!hadErrors)
8919+
return SolutionKind::Solved;
8920+
}
8921+
}
8922+
}
8923+
88378924
if (!candidates.empty()) {
88388925
addOverloadSet(candidates, locator);
88398926
return SolutionKind::Solved;

test/Constraints/result_builder.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,3 +882,40 @@ func test_weak_optionality_stays_the_same() {
882882
}
883883
}
884884
}
885+
886+
enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
887+
case known(Wrapped)
888+
889+
static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
890+
switch rhs {
891+
case .known(let wrapped):
892+
return wrapped == lhs
893+
}
894+
}
895+
}
896+
897+
func test_custom_tilde_equals_operator_matching() {
898+
@resultBuilder
899+
struct Builder {
900+
static func buildBlock<T>(_ t: T) -> T { t }
901+
static func buildEither<T>(first: T) -> T { first }
902+
static func buildEither<T>(second: T) -> T { second }
903+
}
904+
905+
enum TildeTest : String {
906+
case test = "test"
907+
}
908+
909+
struct S {}
910+
911+
struct MyView {
912+
var entry: WrapperEnum<TildeTest>
913+
914+
@Builder var body: S {
915+
switch entry {
916+
case .test: S() // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
917+
case .known(_): S() // Ok - `.known` comes directly from `WrapperEnum`
918+
}
919+
}
920+
}
921+
}

test/Constraints/result_builder_diags.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,8 @@ func getSomeEnumOverloaded(_: Int) -> E2 { return .b(0, nil) }
442442

443443
func testOverloadedSwitch() {
444444
tuplify(true) { c in
445-
// FIXME: Bad source location.
446-
switch getSomeEnumOverloaded(17) { // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
447-
case .a:
445+
switch getSomeEnumOverloaded(17) {
446+
case .a: // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
448447
"a"
449448
default:
450449
"default"
@@ -775,13 +774,14 @@ func test_rdar65667992() {
775774
@Builder var body: S {
776775
switch entry { // expected-error {{type 'E' has no member 'unset'}}
777776
case .set(_, _): S()
778-
case .unset(_): S()
777+
case .unset(_): S() // expected-error {{'_' can only appear in a pattern or on the left side of an assignment}}
779778
default: S()
780779
}
781780
}
782781
}
783782
}
784783

784+
785785
func test_weak_with_nonoptional_type() {
786786
class X {
787787
func test() -> Int { 0 }

test/expr/closure/multi_statement.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,33 @@ func test_workaround_for_optional_void_result() {
178178
let _ = $0
179179
}
180180
}
181+
182+
enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
183+
case known(Wrapped)
184+
185+
static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
186+
switch rhs {
187+
case .known(let wrapped):
188+
return wrapped == lhs
189+
}
190+
}
191+
}
192+
193+
func test_custom_tilde_equals_operator_matching() {
194+
enum TildeTest : String {
195+
case test = "test"
196+
case otherTest = ""
197+
}
198+
199+
func test(_: (WrapperEnum<TildeTest>) -> Void) {}
200+
201+
test { v in
202+
print(v)
203+
204+
switch v {
205+
case .test: break // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
206+
case .otherTest: break // Ok although `.otherTest` comes from `TildeTest` instead of `WrapperEnum`
207+
case .known(_): break // Ok - `.known` comes from `WrapperEnum`
208+
}
209+
}
210+
}

0 commit comments

Comments
 (0)