Skip to content

Commit 26d8552

Browse files
authored
Merge pull request #41589 from xedin/handle-expr-patterns-in-the-solver
[ConstraintSystem] Support solving expression patterns via injecting call to `~=` operator
2 parents 064eb27 + 59154d6 commit 26d8552

File tree

7 files changed

+232
-27
lines changed

7 files changed

+232
-27
lines changed

lib/Sema/CSApply.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9012,6 +9012,31 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
90129012
Type patternType = solution.simplifyType(solution.getType(info.pattern));
90139013
patternType = patternType->reconstituteSugar(/*recursive=*/false);
90149014

9015+
// Check whether this enum element is resolved via ~= application.
9016+
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9017+
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
9018+
auto *EP = target->getExprPattern();
9019+
auto enumType = solution.simplifyType(EP->getType());
9020+
9021+
auto *matchCall = target->getAsExpr();
9022+
9023+
auto *result = matchCall->walk(*this);
9024+
if (!result)
9025+
return None;
9026+
9027+
{
9028+
auto *matchVar = EP->getMatchVar();
9029+
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9030+
}
9031+
9032+
EP->setMatchExpr(result);
9033+
EP->setType(enumType);
9034+
9035+
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9036+
return target;
9037+
}
9038+
}
9039+
90159040
// Coerce the pattern to its appropriate type.
90169041
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
90179042
patternOptions |= TypeResolutionFlags::OverrideType;

lib/Sema/CSSimplify.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8644,6 +8644,69 @@ fixMemberRef(ConstraintSystem &cs, Type baseTy,
86448644
return nullptr;
86458645
}
86468646

8647+
/// Convert the given enum element pattern into an expression pattern
8648+
/// and synthesize ~= operator application to find the type of the
8649+
/// element.
8650+
static bool inferEnumMemberThroughTildeEqualsOperator(
8651+
ConstraintSystem &cs, EnumElementPattern *pattern, Type enumTy,
8652+
Type elementTy, ConstraintLocator *locator) {
8653+
if (!pattern->hasUnresolvedOriginalExpr())
8654+
return true;
8655+
8656+
auto &DC = cs.DC;
8657+
auto &ctx = cs.getASTContext();
8658+
8659+
// Slots for expression and variable are going to be filled via
8660+
// synthesizing ~= operator application.
8661+
auto *EP = new (ctx) ExprPattern(pattern->getUnresolvedOriginalExpr(),
8662+
/*matchExpr=*/nullptr, /*matchVar=*/nullptr);
8663+
8664+
auto tildeEqualsApplication =
8665+
TypeChecker::synthesizeTildeEqualsOperatorApplication(EP, DC, enumTy);
8666+
8667+
if (!tildeEqualsApplication)
8668+
return true;
8669+
8670+
VarDecl *matchVar;
8671+
Expr *matchCall;
8672+
8673+
std::tie(matchVar, matchCall) = *tildeEqualsApplication;
8674+
8675+
// result of ~= operator is always a `Bool`.
8676+
auto target = SolutionApplicationTarget::forExprPattern(
8677+
matchCall, DC, EP, ctx.getBoolDecl()->getDeclaredInterfaceType());
8678+
8679+
DiagnosticTransaction diagnostics(ctx.Diags);
8680+
{
8681+
if (cs.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
8682+
/*leaveClosureBodyUnchecked=*/false)) {
8683+
// Skip diagnostics if they are disabled, otherwise it would result in
8684+
// duplicate diagnostics, since this operation is going to be repeated
8685+
// in diagnostic mode.
8686+
if (!cs.shouldAttemptFixes())
8687+
diagnostics.abort();
8688+
8689+
return true;
8690+
}
8691+
}
8692+
8693+
cs.generateConstraints(target, FreeTypeVariableBinding::Disallow);
8694+
8695+
// Sub-expression associated with expression pattern is the enum element
8696+
// access which needs to be connected to the provided element type.
8697+
cs.addConstraint(ConstraintKind::Conversion, cs.getType(EP->getSubExpr()),
8698+
elementTy, cs.getConstraintLocator(EP));
8699+
8700+
// Store the $match variable and binary expression for solution application.
8701+
EP->setMatchVar(matchVar);
8702+
EP->setMatchExpr(matchCall);
8703+
EP->setType(enumTy);
8704+
8705+
cs.setSolutionApplicationTarget(pattern, target);
8706+
8707+
return false;
8708+
}
8709+
86478710
ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
86488711
ConstraintKind kind, Type baseTy, DeclNameRef member, Type memberTy,
86498712
DeclContext *useDC, FunctionRefKind functionRefKind,
@@ -8849,6 +8912,30 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
88498912
(void)recordFix(fix);
88508913
}
88518914

8915+
// If there were no results from a direct enum lookup, let's attempt
8916+
// to resolve this member via ~= operator application.
8917+
if (candidates.empty()) {
8918+
if (auto patternLoc =
8919+
locator->getLastElementAs<LocatorPathElt::PatternMatch>()) {
8920+
if (auto *enumElement =
8921+
dyn_cast<EnumElementPattern>(patternLoc->getPattern())) {
8922+
auto enumType = baseObjTy->getMetatypeInstanceType();
8923+
8924+
// If the synthesis of ~= resulted in errors (i.e. broken stdlib)
8925+
// that would be diagnosed inline, so let's just fall through and
8926+
// let this situation be diagnosed as a missing member.
8927+
auto hadErrors = inferEnumMemberThroughTildeEqualsOperator(
8928+
*this, enumElement, enumType, memberTy, locator);
8929+
8930+
// Let's consider current member constraint solved because it's
8931+
// replaced by a new set of constraints that would resolve member
8932+
// type.
8933+
if (!hadErrors)
8934+
return SolutionKind::Solved;
8935+
}
8936+
}
8937+
}
8938+
88528939
if (!candidates.empty()) {
88538940
addOverloadSet(candidates, locator);
88548941
return SolutionKind::Solved;

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -849,14 +849,48 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
849849
"typecheck-expr-pattern", EP);
850850
PrettyStackTracePattern stackTrace(Context, "type-checking", EP);
851851

852+
auto tildeEqualsApplication =
853+
synthesizeTildeEqualsOperatorApplication(EP, DC, rhsType);
854+
855+
if (!tildeEqualsApplication)
856+
return true;
857+
858+
VarDecl *matchVar;
859+
Expr *matchCall;
860+
861+
std::tie(matchVar, matchCall) = *tildeEqualsApplication;
862+
863+
// Result of `~=` should always be a boolean.
864+
auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType();
865+
auto target = SolutionApplicationTarget::forExprPattern(matchCall, DC, EP,
866+
contextualTy);
867+
868+
// Check the expression as a condition.
869+
auto result = typeCheckExpression(target);
870+
if (!result)
871+
return true;
872+
873+
// Save the synthesized $match variable in the pattern.
874+
EP->setMatchVar(matchVar);
875+
// Save the type-checked expression in the pattern.
876+
EP->setMatchExpr(result->getAsExpr());
877+
// Set the type on the pattern.
878+
EP->setType(rhsType);
879+
return false;
880+
}
881+
882+
Optional<std::pair<VarDecl *, BinaryExpr *>>
883+
TypeChecker::synthesizeTildeEqualsOperatorApplication(ExprPattern *EP,
884+
DeclContext *DC,
885+
Type enumType) {
886+
auto &Context = DC->getASTContext();
852887
// Create a 'let' binding to stand in for the RHS value.
853888
auto *matchVar =
854889
new (Context) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let,
855890
EP->getLoc(), Context.Id_PatternMatchVar, DC);
856-
matchVar->setInterfaceType(rhsType->mapTypeOutOfContext());
891+
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
857892

858893
matchVar->setImplicit();
859-
EP->setMatchVar(matchVar);
860894

861895
// Find '~=' operators for the match.
862896
auto matchLookup =
@@ -866,19 +900,19 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
866900
auto &diags = DC->getASTContext().Diags;
867901
if (!matchLookup) {
868902
diags.diagnose(EP->getLoc(), diag::no_match_operator);
869-
return true;
903+
return None;
870904
}
871-
905+
872906
SmallVector<ValueDecl*, 4> choices;
873907
for (auto &result : matchLookup) {
874908
choices.push_back(result.getValueDecl());
875909
}
876-
910+
877911
if (choices.empty()) {
878912
diags.diagnose(EP->getLoc(), diag::no_match_operator);
879-
return true;
913+
return None;
880914
}
881-
915+
882916
// Build the 'expr ~= var' expression.
883917
// FIXME: Compound name locations.
884918
auto *matchOp =
@@ -890,24 +924,10 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
890924
auto *matchVarRef = new (Context) DeclRefExpr(matchVar,
891925
DeclNameLoc(EP->getEndLoc()),
892926
/*Implicit=*/true);
893-
Expr *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp,
927+
auto *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp,
894928
matchVarRef, /*implicit*/ true);
895929

896-
// Result of `~=` should always be a boolean.
897-
auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType();
898-
auto target = SolutionApplicationTarget::forExprPattern(matchCall, DC, EP,
899-
contextualTy);
900-
901-
// Check the expression as a condition.
902-
auto result = typeCheckExpression(target);
903-
if (!result)
904-
return true;
905-
906-
// Save the type-checked expression in the pattern.
907-
EP->setMatchExpr(result->getAsExpr());
908-
// Set the type on the pattern.
909-
EP->setType(rhsType);
910-
return false;
930+
return std::make_pair(matchVar, matchCall);
911931
}
912932

913933
static Type replaceArchetypesWithTypeVariables(ConstraintSystem &cs,

lib/Sema/TypeChecker.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,12 @@ Pattern *coercePatternToType(ContextualPattern pattern, Type type,
677677
TypeResolutionOptions options);
678678
bool typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, Type type);
679679

680+
/// Synthesize ~= operator application used to infer enum members
681+
/// in `case` patterns.
682+
Optional<std::pair<VarDecl *, BinaryExpr *>>
683+
synthesizeTildeEqualsOperatorApplication(ExprPattern *EP, DeclContext *DC,
684+
Type enumType);
685+
680686
/// Coerce the specified parameter list of a ClosureExpr to the specified
681687
/// contextual type.
682688
void coerceParameterListToType(ParameterList *P, AnyFunctionType *FN);

test/Constraints/result_builder.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,3 +882,40 @@ func test_weak_optionality_stays_the_same() {
882882
}
883883
}
884884
}
885+
886+
enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
887+
case known(Wrapped)
888+
889+
static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
890+
switch rhs {
891+
case .known(let wrapped):
892+
return wrapped == lhs
893+
}
894+
}
895+
}
896+
897+
func test_custom_tilde_equals_operator_matching() {
898+
@resultBuilder
899+
struct Builder {
900+
static func buildBlock<T>(_ t: T) -> T { t }
901+
static func buildEither<T>(first: T) -> T { first }
902+
static func buildEither<T>(second: T) -> T { second }
903+
}
904+
905+
enum TildeTest : String {
906+
case test = "test"
907+
}
908+
909+
struct S {}
910+
911+
struct MyView {
912+
var entry: WrapperEnum<TildeTest>
913+
914+
@Builder var body: S {
915+
switch entry {
916+
case .test: S() // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
917+
case .known(_): S() // Ok - `.known` comes directly from `WrapperEnum`
918+
}
919+
}
920+
}
921+
}

test/Constraints/result_builder_diags.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,8 @@ func getSomeEnumOverloaded(_: Int) -> E2 { return .b(0, nil) }
442442

443443
func testOverloadedSwitch() {
444444
tuplify(true) { c in
445-
// FIXME: Bad source location.
446-
switch getSomeEnumOverloaded(17) { // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
447-
case .a:
445+
switch getSomeEnumOverloaded(17) {
446+
case .a: // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
448447
"a"
449448
default:
450449
"default"
@@ -775,13 +774,14 @@ func test_rdar65667992() {
775774
@Builder var body: S {
776775
switch entry { // expected-error {{type 'E' has no member 'unset'}}
777776
case .set(_, _): S()
778-
case .unset(_): S()
777+
case .unset(_): S() // expected-error {{'_' can only appear in a pattern or on the left side of an assignment}}
779778
default: S()
780779
}
781780
}
782781
}
783782
}
784783

784+
785785
func test_weak_with_nonoptional_type() {
786786
class X {
787787
func test() -> Int { 0 }

test/expr/closure/multi_statement.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,33 @@ func test_workaround_for_optional_void_result() {
178178
let _ = $0
179179
}
180180
}
181+
182+
enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
183+
case known(Wrapped)
184+
185+
static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
186+
switch rhs {
187+
case .known(let wrapped):
188+
return wrapped == lhs
189+
}
190+
}
191+
}
192+
193+
func test_custom_tilde_equals_operator_matching() {
194+
enum TildeTest : String {
195+
case test = "test"
196+
case otherTest = ""
197+
}
198+
199+
func test(_: (WrapperEnum<TildeTest>) -> Void) {}
200+
201+
test { v in
202+
print(v)
203+
204+
switch v {
205+
case .test: break // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
206+
case .otherTest: break // Ok although `.otherTest` comes from `TildeTest` instead of `WrapperEnum`
207+
case .known(_): break // Ok - `.known` comes from `WrapperEnum`
208+
}
209+
}
210+
}

0 commit comments

Comments
 (0)