Skip to content

[ConstraintSystem] Support solving expression patterns via injecting call to ~= operator #41589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9007,6 +9007,31 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
Type patternType = solution.simplifyType(solution.getType(info.pattern));
patternType = patternType->reconstituteSugar(/*recursive=*/false);

// Check whether this enum element is resolved via ~= application.
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
auto *EP = target->getExprPattern();
auto enumType = solution.simplifyType(EP->getType());

auto *matchCall = target->getAsExpr();

auto *result = matchCall->walk(*this);
if (!result)
return None;

{
auto *matchVar = EP->getMatchVar();
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
}

EP->setMatchExpr(result);
EP->setType(enumType);

(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
return target;
}
}

// Coerce the pattern to its appropriate type.
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
patternOptions |= TypeResolutionFlags::OverrideType;
Expand Down
87 changes: 87 additions & 0 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8629,6 +8629,69 @@ fixMemberRef(ConstraintSystem &cs, Type baseTy,
return nullptr;
}

/// Convert the given enum element pattern into an expression pattern
/// and synthesize ~= operator application to find the type of the
/// element.
static bool inferEnumMemberThroughTildeEqualsOperator(
ConstraintSystem &cs, EnumElementPattern *pattern, Type enumTy,
Type elementTy, ConstraintLocator *locator) {
if (!pattern->hasUnresolvedOriginalExpr())
return true;

auto &DC = cs.DC;
auto &ctx = cs.getASTContext();

// Slots for expression and variable are going to be filled via
// synthesizing ~= operator application.
auto *EP = new (ctx) ExprPattern(pattern->getUnresolvedOriginalExpr(),
/*matchExpr=*/nullptr, /*matchVar=*/nullptr);

auto tildeEqualsApplication =
TypeChecker::synthesizeTildeEqualsOperatorApplication(EP, DC, enumTy);

if (!tildeEqualsApplication)
return true;

VarDecl *matchVar;
Expr *matchCall;

std::tie(matchVar, matchCall) = *tildeEqualsApplication;

// result of ~= operator is always a `Bool`.
auto target = SolutionApplicationTarget::forExprPattern(
matchCall, DC, EP, ctx.getBoolDecl()->getDeclaredInterfaceType());

DiagnosticTransaction diagnostics(ctx.Diags);
{
if (cs.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
/*leaveClosureBodyUnchecked=*/false)) {
// Skip diagnostics if they are disabled, otherwise it would result in
// duplicate diagnostics, since this operation is going to be repeated
// in diagnostic mode.
if (!cs.shouldAttemptFixes())
diagnostics.abort();

return true;
}
}

cs.generateConstraints(target, FreeTypeVariableBinding::Disallow);

// Sub-expression associated with expression pattern is the enum element
// access which needs to be connected to the provided element type.
cs.addConstraint(ConstraintKind::Conversion, cs.getType(EP->getSubExpr()),
elementTy, cs.getConstraintLocator(EP));

// Store the $match variable and binary expression for solution application.
EP->setMatchVar(matchVar);
EP->setMatchExpr(matchCall);
EP->setType(enumTy);

cs.setSolutionApplicationTarget(pattern, target);

return false;
}

ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
ConstraintKind kind, Type baseTy, DeclNameRef member, Type memberTy,
DeclContext *useDC, FunctionRefKind functionRefKind,
Expand Down Expand Up @@ -8834,6 +8897,30 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
(void)recordFix(fix);
}

// If there were no results from a direct enum lookup, let's attempt
// to resolve this member via ~= operator application.
if (candidates.empty()) {
if (auto patternLoc =
locator->getLastElementAs<LocatorPathElt::PatternMatch>()) {
if (auto *enumElement =
dyn_cast<EnumElementPattern>(patternLoc->getPattern())) {
auto enumType = baseObjTy->getMetatypeInstanceType();

// If the synthesis of ~= resulted in errors (i.e. broken stdlib)
// that would be diagnosed inline, so let's just fall through and
// let this situation be diagnosed as a missing member.
auto hadErrors = inferEnumMemberThroughTildeEqualsOperator(
*this, enumElement, enumType, memberTy, locator);

// Let's consider current member constraint solved because it's
// replaced by a new set of constraints that would resolve member
// type.
if (!hadErrors)
return SolutionKind::Solved;
}
}
}

if (!candidates.empty()) {
addOverloadSet(candidates, locator);
return SolutionKind::Solved;
Expand Down
66 changes: 43 additions & 23 deletions lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,14 +849,48 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
"typecheck-expr-pattern", EP);
PrettyStackTracePattern stackTrace(Context, "type-checking", EP);

auto tildeEqualsApplication =
synthesizeTildeEqualsOperatorApplication(EP, DC, rhsType);

if (!tildeEqualsApplication)
return true;

VarDecl *matchVar;
Expr *matchCall;

std::tie(matchVar, matchCall) = *tildeEqualsApplication;

// Result of `~=` should always be a boolean.
auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType();
auto target = SolutionApplicationTarget::forExprPattern(matchCall, DC, EP,
contextualTy);

// Check the expression as a condition.
auto result = typeCheckExpression(target);
if (!result)
return true;

// Save the synthesized $match variable in the pattern.
EP->setMatchVar(matchVar);
// Save the type-checked expression in the pattern.
EP->setMatchExpr(result->getAsExpr());
// Set the type on the pattern.
EP->setType(rhsType);
return false;
}

Optional<std::pair<VarDecl *, BinaryExpr *>>
TypeChecker::synthesizeTildeEqualsOperatorApplication(ExprPattern *EP,
DeclContext *DC,
Type enumType) {
auto &Context = DC->getASTContext();
// Create a 'let' binding to stand in for the RHS value.
auto *matchVar =
new (Context) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let,
EP->getLoc(), Context.Id_PatternMatchVar, DC);
matchVar->setInterfaceType(rhsType->mapTypeOutOfContext());
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());

matchVar->setImplicit();
EP->setMatchVar(matchVar);

// Find '~=' operators for the match.
auto matchLookup =
Expand All @@ -866,19 +900,19 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
auto &diags = DC->getASTContext().Diags;
if (!matchLookup) {
diags.diagnose(EP->getLoc(), diag::no_match_operator);
return true;
return None;
}

SmallVector<ValueDecl*, 4> choices;
for (auto &result : matchLookup) {
choices.push_back(result.getValueDecl());
}

if (choices.empty()) {
diags.diagnose(EP->getLoc(), diag::no_match_operator);
return true;
return None;
}

// Build the 'expr ~= var' expression.
// FIXME: Compound name locations.
auto *matchOp =
Expand All @@ -890,24 +924,10 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
auto *matchVarRef = new (Context) DeclRefExpr(matchVar,
DeclNameLoc(EP->getEndLoc()),
/*Implicit=*/true);
Expr *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp,
auto *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp,
matchVarRef, /*implicit*/ true);

// Result of `~=` should always be a boolean.
auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType();
auto target = SolutionApplicationTarget::forExprPattern(matchCall, DC, EP,
contextualTy);

// Check the expression as a condition.
auto result = typeCheckExpression(target);
if (!result)
return true;

// Save the type-checked expression in the pattern.
EP->setMatchExpr(result->getAsExpr());
// Set the type on the pattern.
EP->setType(rhsType);
return false;
return std::make_pair(matchVar, matchCall);
}

static Type replaceArchetypesWithTypeVariables(ConstraintSystem &cs,
Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,12 @@ Pattern *coercePatternToType(ContextualPattern pattern, Type type,
TypeResolutionOptions options);
bool typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, Type type);

/// Synthesize ~= operator application used to infer enum members
/// in `case` patterns.
Optional<std::pair<VarDecl *, BinaryExpr *>>
synthesizeTildeEqualsOperatorApplication(ExprPattern *EP, DeclContext *DC,
Type enumType);

/// Coerce the specified parameter list of a ClosureExpr to the specified
/// contextual type.
void coerceParameterListToType(ParameterList *P, AnyFunctionType *FN);
Expand Down
37 changes: 37 additions & 0 deletions test/Constraints/result_builder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -882,3 +882,40 @@ func test_weak_optionality_stays_the_same() {
}
}
}

enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
case known(Wrapped)

static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
switch rhs {
case .known(let wrapped):
return wrapped == lhs
}
}
}

func test_custom_tilde_equals_operator_matching() {
@resultBuilder
struct Builder {
static func buildBlock<T>(_ t: T) -> T { t }
static func buildEither<T>(first: T) -> T { first }
static func buildEither<T>(second: T) -> T { second }
}

enum TildeTest : String {
case test = "test"
}

struct S {}

struct MyView {
var entry: WrapperEnum<TildeTest>

@Builder var body: S {
switch entry {
case .test: S() // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
case .known(_): S() // Ok - `.known` comes directly from `WrapperEnum`
}
}
}
}
8 changes: 4 additions & 4 deletions test/Constraints/result_builder_diags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,8 @@ func getSomeEnumOverloaded(_: Int) -> E2 { return .b(0, nil) }

func testOverloadedSwitch() {
tuplify(true) { c in
// FIXME: Bad source location.
switch getSomeEnumOverloaded(17) { // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
case .a:
switch getSomeEnumOverloaded(17) {
case .a: // expected-error{{type 'E2' has no member 'a'; did you mean 'b'?}}
"a"
default:
"default"
Expand Down Expand Up @@ -775,13 +774,14 @@ func test_rdar65667992() {
@Builder var body: S {
switch entry { // expected-error {{type 'E' has no member 'unset'}}
case .set(_, _): S()
case .unset(_): S()
case .unset(_): S() // expected-error {{'_' can only appear in a pattern or on the left side of an assignment}}
default: S()
}
}
}
}


func test_weak_with_nonoptional_type() {
class X {
func test() -> Int { 0 }
Expand Down
30 changes: 30 additions & 0 deletions test/expr/closure/multi_statement.swift
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,33 @@ func test_workaround_for_optional_void_result() {
let _ = $0
}
}

enum WrapperEnum<Wrapped> where Wrapped: RawRepresentable {
case known(Wrapped)

static func ~= (lhs: Wrapped, rhs: WrapperEnum<Wrapped>) -> Bool where Wrapped: Equatable {
switch rhs {
case .known(let wrapped):
return wrapped == lhs
}
}
}

func test_custom_tilde_equals_operator_matching() {
enum TildeTest : String {
case test = "test"
case otherTest = ""
}

func test(_: (WrapperEnum<TildeTest>) -> Void) {}

test { v in
print(v)

switch v {
case .test: break // Ok although `.test` comes from `TildeTest` instead of `WrapperEnum`
case .otherTest: break // Ok although `.otherTest` comes from `TildeTest` instead of `WrapperEnum`
case .known(_): break // Ok - `.known` comes from `WrapperEnum`
}
}
}